diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index 2b3927e7c9..10a344f768 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -401,10 +401,7 @@ def _get_info_schema_columns_query( return query, folded_table_names def _get_column_def_sql(self, column: TColumnSchema, table: PreparedTableSchema = None) -> str: - name = self.sql_client.escape_column_name(column["name"]) - column_def_sql = ( - f"{name} {self.type_mapper.to_destination_type(column, table)} {self._gen_not_null(column.get('nullable', True))}" - ) + column_def_sql = super()._get_column_def_sql(column, table) if column.get(ROUND_HALF_EVEN_HINT, False): column_def_sql += " OPTIONS (rounding_mode='ROUND_HALF_EVEN')" if column.get(ROUND_HALF_AWAY_FROM_ZERO_HINT, False): diff --git a/dlt/destinations/impl/clickhouse/clickhouse.py b/dlt/destinations/impl/clickhouse/clickhouse.py index 3a5f5c3e28..a407e56361 100644 --- a/dlt/destinations/impl/clickhouse/clickhouse.py +++ b/dlt/destinations/impl/clickhouse/clickhouse.py @@ -292,11 +292,10 @@ def _get_table_update_sql( return sql - @staticmethod - def _gen_not_null(v: bool) -> str: + def _gen_not_null(self, v: bool) -> str: # ClickHouse fields are not nullable by default. # We use the `Nullable` modifier instead of NULL / NOT NULL modifiers to cater for ALTER statement. - pass + return "" def _from_db_type( self, ch_t: str, precision: Optional[int], scale: Optional[int] diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 2bb68a607e..a83db6ec34 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -264,12 +264,6 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_destination_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - name = self.sql_client.escape_column_name(c["name"]) - return ( - f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" - ) - def _get_storage_table_query_columns(self) -> List[str]: fields = super()._get_storage_table_query_columns() fields[2] = ( # Override because this is the only way to get data type with precision diff --git a/dlt/destinations/impl/dremio/dremio.py b/dlt/destinations/impl/dremio/dremio.py index ab23f58ab4..e3a090c824 100644 --- a/dlt/destinations/impl/dremio/dremio.py +++ b/dlt/destinations/impl/dremio/dremio.py @@ -151,12 +151,6 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_destination_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - name = self.sql_client.escape_column_name(c["name"]) - return ( - f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" - ) - def _create_merge_followup_jobs( self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: diff --git a/dlt/destinations/impl/duckdb/duck.py b/dlt/destinations/impl/duckdb/duck.py index 3bd4c83e1f..2b3370270b 100644 --- a/dlt/destinations/impl/duckdb/duck.py +++ b/dlt/destinations/impl/duckdb/duck.py @@ -74,17 +74,6 @@ def create_load_job( job = DuckDbCopyJob(file_path) return job - def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - hints_str = " ".join( - self.active_hints.get(h, "") - for h in 
self.active_hints.keys() - if c.get(h, False) is True - ) - column_name = self.sql_client.escape_column_name(c["name"]) - return ( - f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" - ) - def _from_db_type( self, pq_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: diff --git a/dlt/destinations/impl/mssql/mssql.py b/dlt/destinations/impl/mssql/mssql.py index 27aebe07f2..7b48a6b551 100644 --- a/dlt/destinations/impl/mssql/mssql.py +++ b/dlt/destinations/impl/mssql/mssql.py @@ -115,11 +115,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = Non else: db_type = self.type_mapper.to_destination_type(c, table) - hints_str = " ".join( - self.active_hints.get(h, "") - for h in self.active_hints.keys() - if c.get(h, False) is True - ) + hints_str = self._get_column_hints_sql(c) column_name = self.sql_client.escape_column_name(c["name"]) return f"{column_name} {db_type} {hints_str} {self._gen_not_null(c.get('nullable', True))}" diff --git a/dlt/destinations/impl/postgres/postgres.py b/dlt/destinations/impl/postgres/postgres.py index 2459ee1dbe..3d54b59f93 100644 --- a/dlt/destinations/impl/postgres/postgres.py +++ b/dlt/destinations/impl/postgres/postgres.py @@ -161,18 +161,6 @@ def create_load_job( job = PostgresCsvCopyJob(file_path) return job - def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - hints_ = " ".join( - self.active_hints.get(h, "") - for h in self.active_hints.keys() - if c.get(h, False) is True - ) - column_name = self.sql_client.escape_column_name(c["name"]) - nullability = self._gen_not_null(c.get("nullable", True)) - column_type = self.type_mapper.to_destination_type(c, table) - - return f"{column_name} {column_type} {hints_} {nullability}" - def _create_replace_followup_jobs( self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: diff --git a/dlt/destinations/impl/redshift/redshift.py b/dlt/destinations/impl/redshift/redshift.py index 2335166761..b1aa37ce6a 100644 --- a/dlt/destinations/impl/redshift/redshift.py +++ b/dlt/destinations/impl/redshift/redshift.py @@ -153,6 +153,7 @@ def __init__( capabilities, ) super().__init__(schema, config, sql_client) + self.active_hints = HINT_TO_REDSHIFT_ATTR self.sql_client = sql_client self.config: RedshiftClientConfiguration = config self.type_mapper = self.capabilities.get_type_mapper() @@ -162,17 +163,6 @@ def _create_merge_followup_jobs( ) -> List[FollowupJobRequest]: return [RedshiftMergeJob.from_table_chain(table_chain, self.sql_client)] - def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - hints_str = " ".join( - HINT_TO_REDSHIFT_ATTR.get(h, "") - for h in HINT_TO_REDSHIFT_ATTR.keys() - if c.get(h, False) is True - ) - column_name = self.sql_client.escape_column_name(c["name"]) - return ( - f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" - ) - def create_load_job( self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: diff --git a/dlt/destinations/impl/snowflake/configuration.py b/dlt/destinations/impl/snowflake/configuration.py index 4a89a1564b..2e589ea095 100644 --- a/dlt/destinations/impl/snowflake/configuration.py +++ b/dlt/destinations/impl/snowflake/configuration.py @@ -138,6 +138,24 @@ class SnowflakeClientConfiguration(DestinationClientDwhWithStagingConfiguration) query_tag: 
Optional[str] = None
    """A tag with placeholders to tag sessions executing jobs"""
+    create_indexes: bool = False
+    """Whether UNIQUE or PRIMARY KEY constraints should be created"""
+
+    def __init__(
+        self,
+        *,
+        credentials: SnowflakeCredentials = None,
+        create_indexes: bool = False,
+        destination_name: str = None,
+        environment: str = None,
+    ) -> None:
+        super().__init__(
+            credentials=credentials,
+            destination_name=destination_name,
+            environment=environment,
+        )
+        self.create_indexes = create_indexes
+
    def fingerprint(self) -> str:
        """Returns a fingerprint of host part of a connection string"""
        if self.credentials and self.credentials.host:
diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py
index e5146139f2..786cdc0b77 100644
--- a/dlt/destinations/impl/snowflake/snowflake.py
+++ b/dlt/destinations/impl/snowflake/snowflake.py
@@ -1,6 +1,7 @@
-from typing import Optional, Sequence, List
+from typing import Optional, Sequence, List, Dict, Set
from urllib.parse import urlparse, urlunparse

+from dlt.common import logger
from dlt.common.data_writers.configuration import CsvFormatConfiguration
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import (
@@ -15,13 +16,15 @@
    AwsCredentialsWithoutDefaults,
    AzureCredentialsWithoutDefaults,
)
+from dlt.common.schema.utils import get_columns_names_with_prop
from dlt.common.storages.configuration import FilesystemConfiguration, ensure_canonical_az_url
from dlt.common.storages.file_storage import FileStorage
-from dlt.common.schema import TColumnSchema, Schema
-from dlt.common.schema.typing import TColumnType
+from dlt.common.schema import TColumnSchema, Schema, TColumnHint
+from dlt.common.schema.typing import TColumnType, TTableSchema
from dlt.common.storages.fsspec_filesystem import AZURE_BLOB_STORAGE_PROTOCOLS, S3_PROTOCOLS
from dlt.common.typing import TLoaderFileFormat
+from dlt.common.utils import uniq_id

from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset
from dlt.destinations.exceptions import LoadJobTerminalException
@@ -29,6 +32,8 @@
from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient
from dlt.destinations.job_impl import ReferenceFollowupJobRequest

+SUPPORTED_HINTS: Dict[TColumnHint, str] = {"unique": "UNIQUE"}
+

class SnowflakeLoadJob(RunnableLoadJob, HasFollowupJobs):
    def __init__(
@@ -238,6 +243,7 @@ def __init__(
        self.config: SnowflakeClientConfiguration = config
        self.sql_client: SnowflakeSqlClient = sql_client  # type: ignore
        self.type_mapper = self.capabilities.get_type_mapper()
+        self.active_hints = SUPPORTED_HINTS if self.config.create_indexes else {}

    def create_load_job(
        self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False
@@ -264,6 +270,33 @@ def _make_add_column_sql(
            "ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns)
        ]

+    def _get_constraints_sql(
+        self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
+    ) -> str:
+        # "unique" is applied per column via SUPPORTED_HINTS; primary_key becomes a table-level constraint here
+        if self.config.create_indexes:
+            partial: TTableSchema = {
+                "name": table_name,
+                "columns": {c["name"]: c for c in new_columns},
+            }
+            # Add PK constraint if pk_columns exist
+            pk_columns = get_columns_names_with_prop(partial, "primary_key")
+            if pk_columns:
+                if generate_alter:
+                    logger.warning(
+                        f"PRIMARY KEY constraint on {table_name} cannot be added in ALTER TABLE and"
+                        " is ignored"
+                    )
+                else:
+                    pk_constraint_name = list(
+                        self._norm_and_escape_columns(f"PK_{table_name}_{uniq_id(4)}")
+                    )[0]
+                    quoted_pk_cols = ", ".join(
+                        self.sql_client.escape_column_name(col) for col in pk_columns
+                    )
+                    return f",\nCONSTRAINT {pk_constraint_name} PRIMARY KEY ({quoted_pk_cols})"
+        return ""
+
    def _get_table_update_sql(
        self,
        table_name: str,
@@ -287,11 +320,5 @@ def _from_db_type(
    ) -> TColumnType:
        return self.type_mapper.from_destination_type(bq_t, precision, scale)

-    def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
-        name = self.sql_client.escape_column_name(c["name"])
-        return (
-            f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}"
-        )
-
    def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool:
        return self.config.truncate_tables_on_staging_destination_before_load
diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py
index d1f211b1e9..888c80c006 100644
--- a/dlt/destinations/job_client_impl.py
+++ b/dlt/destinations/job_client_impl.py
@@ -7,6 +7,7 @@
from typing import (
    Any,
    ClassVar,
+    Dict,
    List,
    Optional,
    Sequence,
@@ -14,21 +15,18 @@
    Type,
    Iterable,
    Iterator,
-    Generator,
)
import zlib
import re
-from contextlib import contextmanager
-from contextlib import suppress

from dlt.common import pendulum, logger
+from dlt.common.destination.capabilities import DataTypeMapper
from dlt.common.json import json
from dlt.common.schema.typing import (
    C_DLT_LOAD_ID,
    COLUMN_HINTS,
    TColumnType,
    TColumnSchemaBase,
-    TTableFormat,
)
from dlt.common.schema.utils import (
    get_inherited_table_hint,
@@ -40,11 +38,11 @@
from dlt.common.storages import FileStorage
from dlt.common.storages.load_package import LoadJobInfo, ParsedLoadJobFileName
from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables
+from dlt.common.schema import TColumnHint
from dlt.common.destination.reference import (
    PreparedTableSchema,
    StateInfo,
    StorageSchemaInfo,
-    SupportsReadableDataset,
    WithStateSync,
    DestinationClientConfiguration,
    DestinationClientDwhConfiguration,
@@ -55,9 +53,7 @@
    JobClientBase,
    HasFollowupJobs,
    CredentialsConfiguration,
-    SupportsReadableRelation,
)
-from dlt.destinations.dataset import ReadableDBAPIDataset
from dlt.destinations.exceptions import DatabaseUndefinedRelation

from dlt.destinations.job_impl import (
@@ -154,6 +150,8 @@ def __init__(
        self.state_table_columns = ", ".join(
            sql_client.escape_column_name(col) for col in state_table_["columns"]
        )
+        self.active_hints: Dict[TColumnHint, str] = {}
+        self.type_mapper: DataTypeMapper = None
        super().__init__(schema, config, sql_client.capabilities)
        self.sql_client = sql_client
        assert isinstance(config, DestinationClientDwhConfiguration)
@@ -569,6 +567,7 @@ def _get_table_update_sql(
            # build CREATE
            sql = self._make_create_table(qualified_name, table) + " (\n"
            sql += ",\n".join([self._get_column_def_sql(c, table) for c in new_columns])
+            sql += self._get_constraints_sql(table_name, new_columns, generate_alter)
            sql += ")"
            sql_result.append(sql)
        else:
@@ -582,8 +581,16 @@ def _get_table_update_sql(
            sql_result.extend(
                [sql_base + col_statement for col_statement in add_column_statements]
            )
+            constraints_sql = self._get_constraints_sql(table_name, new_columns, generate_alter)
+            if constraints_sql:
+                sql_result.append(constraints_sql)
        return sql_result

+    def _get_constraints_sql(
+        self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
+    ) -> str:
+        return ""
+
    def _check_table_update_hints(
        self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
    ) -> None:
@@ -613,12 +620,22 @@ def _check_table_update_hints(
                    " existing tables."
                )

-    @abstractmethod
    def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
-        pass
+        hints_ = self._get_column_hints_sql(c)
+        column_name = self.sql_client.escape_column_name(c["name"])
+        nullability = self._gen_not_null(c.get("nullable", True))
+        column_type = self.type_mapper.to_destination_type(c, table)
+
+        return f"{column_name} {column_type} {hints_} {nullability}"
+
+    def _get_column_hints_sql(self, c: TColumnSchema) -> str:
+        return " ".join(
+            self.active_hints.get(h, "")
+            for h in self.active_hints.keys()
+            if c.get(h, False) is True  # use ColumnPropInfos to get default value
+        )

-    @staticmethod
-    def _gen_not_null(nullable: bool) -> str:
+    def _gen_not_null(self, nullable: bool) -> str:
        return "NOT NULL" if not nullable else ""

    def _create_table_update(
diff --git a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md
index 07cf822973..28684c39ac 100644
--- a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md
+++ b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md
@@ -200,6 +200,12 @@
Note that we ignore missing columns `ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE` and
## Supported column hints
Snowflake supports the following [column hints](../../general-usage/schema#tables-and-columns):
* `cluster` - Creates a cluster column(s). Many columns per table are supported and only when a new table is created.
+* `unique` - Creates a UNIQUE hint on a Snowflake column; it can be added to many columns. ([optional](#additional-destination-options))
+* `primary_key` - Creates a PRIMARY KEY on the selected column(s); it may be compound. ([optional](#additional-destination-options))
+
+`unique` and `primary_key` are not enforced, and `dlt` does not instruct Snowflake to `RELY` on them during
+query planning.
+
## Table and column identifiers
Snowflake supports both case-sensitive and case-insensitive identifiers. All unquoted and uppercase identifiers resolve case-insensitively in SQL statements. Case-insensitive [naming conventions](../../general-usage/naming-convention.md#case-sensitive-and-insensitive-destinations) like the default **snake_case** will generate case-insensitive identifiers. Case-sensitive (like **sql_cs_v1**) will generate
@@ -308,6 +314,7 @@ pipeline = dlt.pipeline(
## Additional destination options
You can define your own stage to PUT files and disable the removal of the staged files after loading.
+You can also opt in to [creating indexes](#supported-column-hints).
```toml [destination.snowflake] @@ -315,6 +322,8 @@ You can define your own stage to PUT files and disable the removal of the staged stage_name="DLT_STAGE" # Whether to keep or delete the staged files after COPY INTO succeeds keep_staged_files=true +# Add UNIQUE and PRIMARY KEY hints to tables +create_indexes=true ``` ### Setting up CSV format diff --git a/tests/load/bigquery/test_bigquery_table_builder.py b/tests/load/bigquery/test_bigquery_table_builder.py index 56a674cfa3..b2857b7c08 100644 --- a/tests/load/bigquery/test_bigquery_table_builder.py +++ b/tests/load/bigquery/test_bigquery_table_builder.py @@ -107,25 +107,25 @@ def test_create_table(gcp_client: BigQueryClient) -> None: sqlfluff.parse(sql, dialect="bigquery") assert sql.startswith("CREATE TABLE") assert "event_test_table" in sql - assert "`col1` INT64 NOT NULL" in sql - assert "`col2` FLOAT64 NOT NULL" in sql - assert "`col3` BOOL NOT NULL" in sql - assert "`col4` TIMESTAMP NOT NULL" in sql + assert "`col1` INT64 NOT NULL" in sql + assert "`col2` FLOAT64 NOT NULL" in sql + assert "`col3` BOOL NOT NULL" in sql + assert "`col4` TIMESTAMP NOT NULL" in sql assert "`col5` STRING " in sql - assert "`col6` NUMERIC(38,9) NOT NULL" in sql + assert "`col6` NUMERIC(38,9) NOT NULL" in sql assert "`col7` BYTES" in sql assert "`col8` BIGNUMERIC" in sql - assert "`col9` JSON NOT NULL" in sql + assert "`col9` JSON NOT NULL" in sql assert "`col10` DATE" in sql assert "`col11` TIME" in sql - assert "`col1_precision` INT64 NOT NULL" in sql - assert "`col4_precision` TIMESTAMP NOT NULL" in sql + assert "`col1_precision` INT64 NOT NULL" in sql + assert "`col4_precision` TIMESTAMP NOT NULL" in sql assert "`col5_precision` STRING(25) " in sql - assert "`col6_precision` NUMERIC(6,2) NOT NULL" in sql + assert "`col6_precision` NUMERIC(6,2) NOT NULL" in sql assert "`col7_precision` BYTES(19)" in sql - assert "`col11_precision` TIME NOT NULL" in sql - assert "`col_high_p_decimal` BIGNUMERIC(76,0) NOT NULL" in sql - assert "`col_high_s_decimal` BIGNUMERIC(38,24) NOT NULL" in sql + assert "`col11_precision` TIME NOT NULL" in sql + assert "`col_high_p_decimal` BIGNUMERIC(76,0) NOT NULL" in sql + assert "`col_high_s_decimal` BIGNUMERIC(38,24) NOT NULL" in sql assert "CLUSTER BY" not in sql assert "PARTITION BY" not in sql @@ -137,29 +137,29 @@ def test_alter_table(gcp_client: BigQueryClient) -> None: assert sql.startswith("ALTER TABLE") assert sql.count("ALTER TABLE") == 1 assert "event_test_table" in sql - assert "ADD COLUMN `col1` INT64 NOT NULL" in sql - assert "ADD COLUMN `col2` FLOAT64 NOT NULL" in sql - assert "ADD COLUMN `col3` BOOL NOT NULL" in sql - assert "ADD COLUMN `col4` TIMESTAMP NOT NULL" in sql + assert "ADD COLUMN `col1` INT64 NOT NULL" in sql + assert "ADD COLUMN `col2` FLOAT64 NOT NULL" in sql + assert "ADD COLUMN `col3` BOOL NOT NULL" in sql + assert "ADD COLUMN `col4` TIMESTAMP NOT NULL" in sql assert "ADD COLUMN `col5` STRING" in sql - assert "ADD COLUMN `col6` NUMERIC(38,9) NOT NULL" in sql + assert "ADD COLUMN `col6` NUMERIC(38,9) NOT NULL" in sql assert "ADD COLUMN `col7` BYTES" in sql assert "ADD COLUMN `col8` BIGNUMERIC" in sql - assert "ADD COLUMN `col9` JSON NOT NULL" in sql + assert "ADD COLUMN `col9` JSON NOT NULL" in sql assert "ADD COLUMN `col10` DATE" in sql assert "ADD COLUMN `col11` TIME" in sql - assert "ADD COLUMN `col1_precision` INT64 NOT NULL" in sql - assert "ADD COLUMN `col4_precision` TIMESTAMP NOT NULL" in sql + assert "ADD COLUMN `col1_precision` INT64 NOT NULL" in sql + assert "ADD COLUMN `col4_precision` 
TIMESTAMP NOT NULL" in sql assert "ADD COLUMN `col5_precision` STRING(25)" in sql - assert "ADD COLUMN `col6_precision` NUMERIC(6,2) NOT NULL" in sql + assert "ADD COLUMN `col6_precision` NUMERIC(6,2) NOT NULL" in sql assert "ADD COLUMN `col7_precision` BYTES(19)" in sql - assert "ADD COLUMN `col11_precision` TIME NOT NULL" in sql + assert "ADD COLUMN `col11_precision` TIME NOT NULL" in sql # table has col1 already in storage mod_table = deepcopy(TABLE_UPDATE) mod_table.pop(0) sql = gcp_client._get_table_update_sql("event_test_table", mod_table, True)[0] - assert "ADD COLUMN `col1` INTEGER NOT NULL" not in sql - assert "ADD COLUMN `col2` FLOAT64 NOT NULL" in sql + assert "ADD COLUMN `col1` INTEGER NOT NULL" not in sql + assert "ADD COLUMN `col2` FLOAT64 NOT NULL" in sql def test_create_table_case_insensitive(ci_gcp_client: BigQueryClient) -> None: diff --git a/tests/load/dremio/test_dremio_client.py b/tests/load/dremio/test_dremio_client.py index efc72c0652..98212efb13 100644 --- a/tests/load/dremio/test_dremio_client.py +++ b/tests/load/dremio/test_dremio_client.py @@ -48,12 +48,12 @@ def test_dremio_factory() -> None: [ TColumnSchema(name="foo", data_type="text", partition=True), TColumnSchema(name="bar", data_type="bigint", sort=True), - TColumnSchema(name="baz", data_type="double"), + TColumnSchema(name="baz", data_type="double", nullable=False), ], False, [ 'CREATE TABLE "test_database"."test_dataset"."event_test_table"' - ' (\n"foo" VARCHAR ,\n"bar" BIGINT ,\n"baz" DOUBLE )\nPARTITION BY' + ' (\n"foo" VARCHAR ,\n"bar" BIGINT ,\n"baz" DOUBLE NOT NULL)\nPARTITION BY' ' ("foo")\nLOCALSORT BY ("bar")' ], ), @@ -66,7 +66,7 @@ def test_dremio_factory() -> None: False, [ 'CREATE TABLE "test_database"."test_dataset"."event_test_table"' - ' (\n"foo" VARCHAR ,\n"bar" BIGINT ,\n"baz" DOUBLE )\nPARTITION BY' + ' (\n"foo" VARCHAR ,\n"bar" BIGINT ,\n"baz" DOUBLE )\nPARTITION BY' ' ("foo","bar")' ], ), @@ -79,7 +79,7 @@ def test_dremio_factory() -> None: False, [ 'CREATE TABLE "test_database"."test_dataset"."event_test_table"' - ' (\n"foo" VARCHAR ,\n"bar" BIGINT ,\n"baz" DOUBLE )' + ' (\n"foo" VARCHAR ,\n"bar" BIGINT ,\n"baz" DOUBLE )' ], ), ], diff --git a/tests/load/snowflake/test_snowflake_client.py b/tests/load/snowflake/test_snowflake_client.py index aebf514b56..674e01ba31 100644 --- a/tests/load/snowflake/test_snowflake_client.py +++ b/tests/load/snowflake/test_snowflake_client.py @@ -1,14 +1,17 @@ +from copy import deepcopy import os from typing import Iterator from pytest_mock import MockerFixture import pytest -from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient +from dlt.common.schema.schema import Schema +from dlt.destinations.impl.snowflake.snowflake import SUPPORTED_HINTS, SnowflakeClient from dlt.destinations.job_client_impl import SqlJobClientBase from dlt.destinations.sql_client import TJobQueryTags -from tests.load.utils import yield_client_with_storage +from tests.cases import TABLE_UPDATE +from tests.load.utils import yield_client_with_storage, empty_schema # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -32,6 +35,39 @@ def client() -> Iterator[SqlJobClientBase]: yield from yield_client_with_storage("snowflake") +def test_create_table_with_hints(client: SnowflakeClient, empty_schema: Schema) -> None: + mod_update = deepcopy(TABLE_UPDATE[:11]) + # mock hints + client.config.create_indexes = True + client.active_hints = SUPPORTED_HINTS + client.schema = empty_schema + + mod_update[0]["primary_key"] = True + 
mod_update[5]["primary_key"] = True + + mod_update[0]["sort"] = True + mod_update[4]["parent_key"] = True + + # unique constraints are always single columns + mod_update[1]["unique"] = True + mod_update[7]["unique"] = True + + sql = ";".join(client._get_table_update_sql("event_test_table", mod_update, False)) + + print(sql) + client.sql_client.execute_sql(sql) + + # generate alter table + mod_update = deepcopy(TABLE_UPDATE[11:]) + mod_update[0]["primary_key"] = True + mod_update[1]["unique"] = True + + sql = ";".join(client._get_table_update_sql("event_test_table", mod_update, True)) + + print(sql) + client.sql_client.execute_sql(sql) + + def test_query_tag(client: SnowflakeClient, mocker: MockerFixture): assert client.config.query_tag == QUERY_TAG # make sure we generate proper query diff --git a/tests/load/snowflake/test_snowflake_table_builder.py b/tests/load/snowflake/test_snowflake_table_builder.py index 1fc0034f43..43d4395188 100644 --- a/tests/load/snowflake/test_snowflake_table_builder.py +++ b/tests/load/snowflake/test_snowflake_table_builder.py @@ -6,7 +6,7 @@ from dlt.common.utils import uniq_id from dlt.common.schema import Schema, utils from dlt.destinations import snowflake -from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient +from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient, SUPPORTED_HINTS from dlt.destinations.impl.snowflake.configuration import ( SnowflakeClientConfiguration, SnowflakeCredentials, @@ -66,16 +66,63 @@ def test_create_table(snowflake_client: SnowflakeClient) -> None: assert sql.strip().startswith("CREATE TABLE") assert "EVENT_TEST_TABLE" in sql - assert '"COL1" NUMBER(19,0) NOT NULL' in sql - assert '"COL2" FLOAT NOT NULL' in sql - assert '"COL3" BOOLEAN NOT NULL' in sql - assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql assert '"COL5" VARCHAR' in sql - assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' in sql assert '"COL7" BINARY' in sql assert '"COL8" NUMBER(38,0)' in sql - assert '"COL9" VARIANT NOT NULL' in sql - assert '"COL10" DATE NOT NULL' in sql + assert '"COL9" VARIANT NOT NULL' in sql + assert '"COL10" DATE NOT NULL' in sql + + +def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None: + mod_update = deepcopy(TABLE_UPDATE[:11]) + # mock hints + snowflake_client.config.create_indexes = True + snowflake_client.active_hints = SUPPORTED_HINTS + + mod_update[0]["primary_key"] = True + mod_update[5]["primary_key"] = True + + mod_update[0]["sort"] = True + + # unique constraints are always single columns + mod_update[1]["unique"] = True + mod_update[7]["unique"] = True + + mod_update[4]["parent_key"] = True + + sql = ";".join(snowflake_client._get_table_update_sql("event_test_table", mod_update, False)) + + assert sql.strip().startswith("CREATE TABLE") + assert "EVENT_TEST_TABLE" in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT UNIQUE NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL5" VARCHAR' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL7" BINARY' in sql + assert '"COL8" NUMBER(38,0) UNIQUE' in sql + assert '"COL9" VARIANT NOT NULL' in sql + assert '"COL10" DATE NOT NULL' in sql + + # PRIMARY KEY constraint + assert 'CONSTRAINT "PK_EVENT_TEST_TABLE_' in sql + assert 
'PRIMARY KEY ("COL1", "COL6")' in sql + + # generate alter + mod_update = deepcopy(TABLE_UPDATE[11:]) + mod_update[0]["primary_key"] = True + mod_update[1]["unique"] = True + + sql = ";".join(snowflake_client._get_table_update_sql("event_test_table", mod_update, True)) + # PK constraint ignored for alter + assert "PRIMARY KEY" not in sql + assert '"COL2_NULL" FLOAT UNIQUE' in sql def test_alter_table(snowflake_client: SnowflakeClient) -> None: @@ -90,15 +137,15 @@ def test_alter_table(snowflake_client: SnowflakeClient) -> None: assert sql.count("ALTER TABLE") == 1 assert sql.count("ADD COLUMN") == 1 assert '"EVENT_TEST_TABLE"' in sql - assert '"COL1" NUMBER(19,0) NOT NULL' in sql - assert '"COL2" FLOAT NOT NULL' in sql - assert '"COL3" BOOLEAN NOT NULL' in sql - assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql assert '"COL5" VARCHAR' in sql - assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' in sql assert '"COL7" BINARY' in sql assert '"COL8" NUMBER(38,0)' in sql - assert '"COL9" VARIANT NOT NULL' in sql + assert '"COL9" VARIANT NOT NULL' in sql assert '"COL10" DATE' in sql mod_table = deepcopy(TABLE_UPDATE) @@ -106,7 +153,7 @@ def test_alter_table(snowflake_client: SnowflakeClient) -> None: sql = snowflake_client._get_table_update_sql("event_test_table", mod_table, True)[0] assert '"COL1"' not in sql - assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql def test_create_table_case_sensitive(cs_client: SnowflakeClient) -> None:
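A minimal usage sketch of the new `create_indexes` option (illustrative only): it assumes `create_indexes=true` is set under `[destination.snowflake]` in `config.toml` as documented above, and the resource, column, and pipeline names are placeholders.

```py
import dlt

# "id" carries the primary_key hint: on CREATE TABLE it is emitted as a table-level
# CONSTRAINT "PK_USERS_..." PRIMARY KEY ("ID").
# "email" carries the unique hint: it is emitted as UNIQUE on the column definition.
@dlt.resource(primary_key="id", columns={"email": {"unique": True}})
def users():
    yield {"id": 1, "email": "alice@example.com"}


pipeline = dlt.pipeline(pipeline_name="users_pipeline", destination="snowflake")
# Snowflake does not enforce these constraints and dlt does not instruct it to RELY on them;
# they are informational hints only.
pipeline.run(users())
```

Both constraints are emitted only when a table is created; on ALTER TABLE the PRIMARY KEY constraint is skipped with a warning, as implemented in `_get_constraints_sql`.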