From 9f0d67e43a556f23517f0244e82bec74e44a7481 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Thu, 12 Dec 2024 16:45:42 +0100 Subject: [PATCH 1/8] feat: snowflake hints --- .../impl/snowflake/configuration.py | 19 ++++++++++ dlt/destinations/impl/snowflake/snowflake.py | 16 ++++++--- .../snowflake/test_snowflake_table_builder.py | 36 +++++++++++++++++++ 3 files changed, 67 insertions(+), 4 deletions(-) diff --git a/dlt/destinations/impl/snowflake/configuration.py b/dlt/destinations/impl/snowflake/configuration.py index 4a89a1564b..4355edb09c 100644 --- a/dlt/destinations/impl/snowflake/configuration.py +++ b/dlt/destinations/impl/snowflake/configuration.py @@ -138,6 +138,25 @@ class SnowflakeClientConfiguration(DestinationClientDwhWithStagingConfiguration) query_tag: Optional[str] = None """A tag with placeholders to tag sessions executing jobs""" + # TODO: decide name - create_indexes vs create_constraints (create_indexes used in other destinations) + create_indexes: bool = False + """Whether UNIQUE or PRIMARY KEY constrains should be created""" + + def __init__( + self, + *, + credentials: SnowflakeCredentials = None, + create_indexes: bool = False, + destination_name: str = None, + environment: str = None, + ) -> None: + super().__init__( + credentials=credentials, + destination_name=destination_name, + environment=environment, + ) + self.create_indexes = create_indexes + def fingerprint(self) -> str: """Returns a fingerprint of host part of a connection string""" if self.credentials and self.credentials.host: diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index e5146139f2..c6220fd65e 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, List +from typing import Optional, Sequence, List, Dict from urllib.parse import urlparse, urlunparse from dlt.common.data_writers.configuration import CsvFormatConfiguration @@ -17,7 +17,7 @@ ) from dlt.common.storages.configuration import FilesystemConfiguration, ensure_canonical_az_url from dlt.common.storages.file_storage import FileStorage -from dlt.common.schema import TColumnSchema, Schema +from dlt.common.schema import TColumnSchema, Schema, TColumnHint from dlt.common.schema.typing import TColumnType from dlt.common.storages.fsspec_filesystem import AZURE_BLOB_STORAGE_PROTOCOLS, S3_PROTOCOLS @@ -29,6 +29,8 @@ from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient from dlt.destinations.job_impl import ReferenceFollowupJobRequest +SUPPORTED_HINTS: Dict[TColumnHint, str] = {"unique": "UNIQUE", "primary_key": "PRIMARY KEY"} + class SnowflakeLoadJob(RunnableLoadJob, HasFollowupJobs): def __init__( @@ -238,6 +240,7 @@ def __init__( self.config: SnowflakeClientConfiguration = config self.sql_client: SnowflakeSqlClient = sql_client # type: ignore self.type_mapper = self.capabilities.get_type_mapper() + self.active_hints = SUPPORTED_HINTS if self.config.create_indexes else {} def create_load_job( self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False @@ -288,9 +291,14 @@ def _from_db_type( return self.type_mapper.from_destination_type(bq_t, precision, scale) def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - name = self.sql_client.escape_column_name(c["name"]) + hints_str = " ".join( + self.active_hints.get(h, "") + for h in self.active_hints.keys() 
+ if c.get(h, False) is True + ) + column_name = self.sql_client.escape_column_name(c["name"]) return ( - f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" + f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: diff --git a/tests/load/snowflake/test_snowflake_table_builder.py b/tests/load/snowflake/test_snowflake_table_builder.py index 1fc0034f43..4b8c4e1b2a 100644 --- a/tests/load/snowflake/test_snowflake_table_builder.py +++ b/tests/load/snowflake/test_snowflake_table_builder.py @@ -78,6 +78,42 @@ def test_create_table(snowflake_client: SnowflakeClient) -> None: assert '"COL10" DATE NOT NULL' in sql +def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None: + mod_update = deepcopy(TABLE_UPDATE) + + mod_update[0]["primary_key"] = True + mod_update[0]["sort"] = True + mod_update[1]["unique"] = True + mod_update[4]["parent_key"] = True + + sql = ";".join(snowflake_client._get_table_update_sql("event_test_table", mod_update, False)) + + assert sql.strip().startswith("CREATE TABLE") + assert "EVENT_TEST_TABLE" in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL5" VARCHAR' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL7" BINARY' in sql + assert '"COL8" NUMBER(38,0)' in sql + assert '"COL9" VARIANT NOT NULL' in sql + assert '"COL10" DATE NOT NULL' in sql + + # same thing with indexes + snowflake_client = snowflake().client( + snowflake_client.schema, + SnowflakeClientConfiguration(create_indexes=True)._bind_dataset_name( + dataset_name="test_" + uniq_id() + ), + ) + sql = snowflake_client._get_table_update_sql("event_test_table", mod_update, False)[0] + sqlfluff.parse(sql) + assert '"COL1" NUMBER(19,0) PRIMARY KEY NOT NULL' in sql + assert '"COL2" FLOAT UNIQUE NOT NULL' in sql + + def test_alter_table(snowflake_client: SnowflakeClient) -> None: statements = snowflake_client._get_table_update_sql("event_test_table", TABLE_UPDATE, True) assert len(statements) == 1 From 96003f458f71d41b4277296920ba85872cd4dc6b Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Mon, 16 Dec 2024 10:49:34 +0100 Subject: [PATCH 2/8] fix: statement extra space --- .../snowflake/test_snowflake_table_builder.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/load/snowflake/test_snowflake_table_builder.py b/tests/load/snowflake/test_snowflake_table_builder.py index 4b8c4e1b2a..bf55fe9dc6 100644 --- a/tests/load/snowflake/test_snowflake_table_builder.py +++ b/tests/load/snowflake/test_snowflake_table_builder.py @@ -66,16 +66,16 @@ def test_create_table(snowflake_client: SnowflakeClient) -> None: assert sql.strip().startswith("CREATE TABLE") assert "EVENT_TEST_TABLE" in sql - assert '"COL1" NUMBER(19,0) NOT NULL' in sql - assert '"COL2" FLOAT NOT NULL' in sql - assert '"COL3" BOOLEAN NOT NULL' in sql - assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql assert '"COL5" VARCHAR' in sql - assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' 
in sql assert '"COL7" BINARY' in sql assert '"COL8" NUMBER(38,0)' in sql - assert '"COL9" VARIANT NOT NULL' in sql - assert '"COL10" DATE NOT NULL' in sql + assert '"COL9" VARIANT NOT NULL' in sql + assert '"COL10" DATE NOT NULL' in sql def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None: @@ -126,15 +126,15 @@ def test_alter_table(snowflake_client: SnowflakeClient) -> None: assert sql.count("ALTER TABLE") == 1 assert sql.count("ADD COLUMN") == 1 assert '"EVENT_TEST_TABLE"' in sql - assert '"COL1" NUMBER(19,0) NOT NULL' in sql - assert '"COL2" FLOAT NOT NULL' in sql - assert '"COL3" BOOLEAN NOT NULL' in sql - assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql assert '"COL5" VARCHAR' in sql - assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' in sql assert '"COL7" BINARY' in sql assert '"COL8" NUMBER(38,0)' in sql - assert '"COL9" VARIANT NOT NULL' in sql + assert '"COL9" VARIANT NOT NULL' in sql assert '"COL10" DATE' in sql mod_table = deepcopy(TABLE_UPDATE) @@ -142,7 +142,7 @@ def test_alter_table(snowflake_client: SnowflakeClient) -> None: sql = snowflake_client._get_table_update_sql("event_test_table", mod_table, True)[0] assert '"COL1"' not in sql - assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql def test_create_table_case_sensitive(cs_client: SnowflakeClient) -> None: From ee71b1787814991d9e0d8314de581095557000e4 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Mon, 16 Dec 2024 16:51:57 +0100 Subject: [PATCH 3/8] refactor snowflake constraints support --- dlt/destinations/impl/snowflake/snowflake.py | 71 +++++++++++++++++-- .../snowflake/test_snowflake_table_builder.py | 57 ++++++++------- 2 files changed, 96 insertions(+), 32 deletions(-) diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index c6220fd65e..da6597ecd4 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, List, Dict +from typing import Optional, Sequence, List, Dict, Set from urllib.parse import urlparse, urlunparse from dlt.common.data_writers.configuration import CsvFormatConfiguration @@ -267,6 +267,61 @@ def _make_add_column_sql( "ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns) ] + def _get_existing_constraints(self, table_name: str) -> Set[str]: + query = f""" + SELECT constraint_name + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS + WHERE TABLE_NAME = '{table_name.upper()}' + """ + + if self.sql_client.catalog_name: + query += f" AND CONSTRAINT_CATALOG = '{self.sql_client.catalog_name}'" + + with self.sql_client.open_connection() as conn: + cursors = conn.execute_string(query) + existing_names = set() + for cursor in cursors: + for row in cursor: + existing_names.add(row[0]) + return existing_names + + def _get_constraints_statement( + self, table_name: str, columns: Sequence[TColumnSchema], existing_constraints: Set[str] + ) -> List[str]: + statements = [] + pk_constraint_name = f"PK_{table_name.upper()}" + uq_constraint_name = f"UQ_{table_name.upper()}" + qualified_name = self.sql_client.make_qualified_table_name(table_name) + + pk_columns = [col["name"] for col in columns if col.get("primary_key")] + unique_columns = 
[col["name"] for col in columns if col.get("unique")] + + # Drop existing PK/UQ constraints if found + if pk_constraint_name in existing_constraints: + statements.append(f"ALTER TABLE {qualified_name} DROP CONSTRAINT {pk_constraint_name}") + if uq_constraint_name in existing_constraints: + statements.append(f"ALTER TABLE {qualified_name} DROP CONSTRAINT {uq_constraint_name}") + + # Add PK constraint if pk_columns exist + if pk_columns: + quoted_pk_cols = ", ".join(f'"{col}"' for col in pk_columns) + statements.append( + f"ALTER TABLE {qualified_name} " + f"ADD CONSTRAINT {pk_constraint_name} " + f"PRIMARY KEY ({quoted_pk_cols})" + ) + + # Add UNIQUE constraint if unique_columns exist + if unique_columns: + quoted_uq_cols = ", ".join(f'"{col}"' for col in unique_columns) + statements.append( + f"ALTER TABLE {qualified_name} " + f"ADD CONSTRAINT {uq_constraint_name} " + f"UNIQUE ({quoted_uq_cols})" + ) + + return statements + def _get_table_update_sql( self, table_name: str, @@ -283,6 +338,13 @@ def _get_table_update_sql( if cluster_list: sql[0] = sql[0] + "\nCLUSTER BY (" + ",".join(cluster_list) + ")" + if self.active_hints: + existing_constraints = self._get_existing_constraints(table_name) + statements = self._get_constraints_statement( + table_name, new_columns, existing_constraints + ) + sql.extend(statements) + return sql def _from_db_type( @@ -291,14 +353,9 @@ def _from_db_type( return self.type_mapper.from_destination_type(bq_t, precision, scale) def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - hints_str = " ".join( - self.active_hints.get(h, "") - for h in self.active_hints.keys() - if c.get(h, False) is True - ) column_name = self.sql_client.escape_column_name(c["name"]) return ( - f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" + f"{column_name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" ) def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: diff --git a/tests/load/snowflake/test_snowflake_table_builder.py b/tests/load/snowflake/test_snowflake_table_builder.py index bf55fe9dc6..e2bd27d18a 100644 --- a/tests/load/snowflake/test_snowflake_table_builder.py +++ b/tests/load/snowflake/test_snowflake_table_builder.py @@ -66,16 +66,16 @@ def test_create_table(snowflake_client: SnowflakeClient) -> None: assert sql.strip().startswith("CREATE TABLE") assert "EVENT_TEST_TABLE" in sql - assert '"COL1" NUMBER(19,0) NOT NULL' in sql - assert '"COL2" FLOAT NOT NULL' in sql - assert '"COL3" BOOLEAN NOT NULL' in sql - assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql assert '"COL5" VARCHAR' in sql - assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' in sql assert '"COL7" BINARY' in sql assert '"COL8" NUMBER(38,0)' in sql - assert '"COL9" VARIANT NOT NULL' in sql - assert '"COL10" DATE NOT NULL' in sql + assert '"COL9" VARIANT NOT NULL' in sql + assert '"COL10" DATE NOT NULL' in sql def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None: @@ -90,16 +90,16 @@ def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None: assert sql.strip().startswith("CREATE TABLE") assert "EVENT_TEST_TABLE" in sql - assert '"COL1" NUMBER(19,0) NOT NULL' in sql - assert 
'"COL2" FLOAT NOT NULL' in sql - assert '"COL3" BOOLEAN NOT NULL' in sql - assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql assert '"COL5" VARCHAR' in sql - assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' in sql assert '"COL7" BINARY' in sql assert '"COL8" NUMBER(38,0)' in sql - assert '"COL9" VARIANT NOT NULL' in sql - assert '"COL10" DATE NOT NULL' in sql + assert '"COL9" VARIANT NOT NULL' in sql + assert '"COL10" DATE NOT NULL' in sql # same thing with indexes snowflake_client = snowflake().client( @@ -108,10 +108,17 @@ def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None: dataset_name="test_" + uniq_id() ), ) - sql = snowflake_client._get_table_update_sql("event_test_table", mod_update, False)[0] - sqlfluff.parse(sql) - assert '"COL1" NUMBER(19,0) PRIMARY KEY NOT NULL' in sql - assert '"COL2" FLOAT UNIQUE NOT NULL' in sql + sql_statements = snowflake_client._get_table_update_sql("event_test_table", mod_update, False) + + for stmt in sql_statements: + sqlfluff.parse(stmt) + + assert any( + 'ADD CONSTRAINT PK_EVENT_TEST_TABLE PRIMARY KEY ("col1")' in stmt for stmt in sql_statements + ) + assert any( + 'ADD CONSTRAINT UQ_EVENT_TEST_TABLE UNIQUE ("col2")' in stmt for stmt in sql_statements + ) def test_alter_table(snowflake_client: SnowflakeClient) -> None: @@ -126,15 +133,15 @@ def test_alter_table(snowflake_client: SnowflakeClient) -> None: assert sql.count("ALTER TABLE") == 1 assert sql.count("ADD COLUMN") == 1 assert '"EVENT_TEST_TABLE"' in sql - assert '"COL1" NUMBER(19,0) NOT NULL' in sql - assert '"COL2" FLOAT NOT NULL' in sql - assert '"COL3" BOOLEAN NOT NULL' in sql - assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql assert '"COL5" VARCHAR' in sql - assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' in sql assert '"COL7" BINARY' in sql assert '"COL8" NUMBER(38,0)' in sql - assert '"COL9" VARIANT NOT NULL' in sql + assert '"COL9" VARIANT NOT NULL' in sql assert '"COL10" DATE' in sql mod_table = deepcopy(TABLE_UPDATE) @@ -142,7 +149,7 @@ def test_alter_table(snowflake_client: SnowflakeClient) -> None: sql = snowflake_client._get_table_update_sql("event_test_table", mod_table, True)[0] assert '"COL1"' not in sql - assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql def test_create_table_case_sensitive(cs_client: SnowflakeClient) -> None: From 07206d7d69ed865e535e38d1a81c8647a4f19fe1 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Mon, 16 Dec 2024 16:54:24 +0100 Subject: [PATCH 4/8] revert changes --- dlt/destinations/impl/snowflake/configuration.py | 1 - dlt/destinations/impl/snowflake/snowflake.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/impl/snowflake/configuration.py b/dlt/destinations/impl/snowflake/configuration.py index 4355edb09c..2e589ea095 100644 --- a/dlt/destinations/impl/snowflake/configuration.py +++ b/dlt/destinations/impl/snowflake/configuration.py @@ -138,7 +138,6 @@ class SnowflakeClientConfiguration(DestinationClientDwhWithStagingConfiguration) query_tag: Optional[str] = None """A tag with placeholders to tag sessions 
executing jobs""" - # TODO: decide name - create_indexes vs create_constraints (create_indexes used in other destinations) create_indexes: bool = False """Whether UNIQUE or PRIMARY KEY constrains should be created""" diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index da6597ecd4..eb74c143b1 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -353,9 +353,9 @@ def _from_db_type( return self.type_mapper.from_destination_type(bq_t, precision, scale) def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - column_name = self.sql_client.escape_column_name(c["name"]) + name = self.sql_client.escape_column_name(c["name"]) return ( - f"{column_name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" + f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" ) def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: From a6fcb85d88105881cba3ca17475e34c861f02cf4 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 17 Dec 2024 00:22:56 +0100 Subject: [PATCH 5/8] emits PKs for snowflake, CREATE only --- dlt/destinations/impl/snowflake/snowflake.py | 100 ++++++------------ tests/load/snowflake/test_snowflake_client.py | 40 ++++++- .../snowflake/test_snowflake_table_builder.py | 84 ++++++++------- 3 files changed, 113 insertions(+), 111 deletions(-) diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index eb74c143b1..786cdc0b77 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -1,6 +1,7 @@ from typing import Optional, Sequence, List, Dict, Set from urllib.parse import urlparse, urlunparse +from dlt.common import logger from dlt.common.data_writers.configuration import CsvFormatConfiguration from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( @@ -15,13 +16,15 @@ AwsCredentialsWithoutDefaults, AzureCredentialsWithoutDefaults, ) +from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages.configuration import FilesystemConfiguration, ensure_canonical_az_url from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema, TColumnHint -from dlt.common.schema.typing import TColumnType +from dlt.common.schema.typing import TColumnType, TTableSchema from dlt.common.storages.fsspec_filesystem import AZURE_BLOB_STORAGE_PROTOCOLS, S3_PROTOCOLS from dlt.common.typing import TLoaderFileFormat +from dlt.common.utils import uniq_id from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset from dlt.destinations.exceptions import LoadJobTerminalException @@ -29,7 +32,7 @@ from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient from dlt.destinations.job_impl import ReferenceFollowupJobRequest -SUPPORTED_HINTS: Dict[TColumnHint, str] = {"unique": "UNIQUE", "primary_key": "PRIMARY KEY"} +SUPPORTED_HINTS: Dict[TColumnHint, str] = {"unique": "UNIQUE"} class SnowflakeLoadJob(RunnableLoadJob, HasFollowupJobs): @@ -267,60 +270,32 @@ def _make_add_column_sql( "ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns) ] - def _get_existing_constraints(self, table_name: str) -> Set[str]: - query = f""" - SELECT constraint_name - FROM 
INFORMATION_SCHEMA.TABLE_CONSTRAINTS - WHERE TABLE_NAME = '{table_name.upper()}' - """ - - if self.sql_client.catalog_name: - query += f" AND CONSTRAINT_CATALOG = '{self.sql_client.catalog_name}'" - - with self.sql_client.open_connection() as conn: - cursors = conn.execute_string(query) - existing_names = set() - for cursor in cursors: - for row in cursor: - existing_names.add(row[0]) - return existing_names - - def _get_constraints_statement( - self, table_name: str, columns: Sequence[TColumnSchema], existing_constraints: Set[str] - ) -> List[str]: - statements = [] - pk_constraint_name = f"PK_{table_name.upper()}" - uq_constraint_name = f"UQ_{table_name.upper()}" - qualified_name = self.sql_client.make_qualified_table_name(table_name) - - pk_columns = [col["name"] for col in columns if col.get("primary_key")] - unique_columns = [col["name"] for col in columns if col.get("unique")] - - # Drop existing PK/UQ constraints if found - if pk_constraint_name in existing_constraints: - statements.append(f"ALTER TABLE {qualified_name} DROP CONSTRAINT {pk_constraint_name}") - if uq_constraint_name in existing_constraints: - statements.append(f"ALTER TABLE {qualified_name} DROP CONSTRAINT {uq_constraint_name}") - - # Add PK constraint if pk_columns exist - if pk_columns: - quoted_pk_cols = ", ".join(f'"{col}"' for col in pk_columns) - statements.append( - f"ALTER TABLE {qualified_name} " - f"ADD CONSTRAINT {pk_constraint_name} " - f"PRIMARY KEY ({quoted_pk_cols})" - ) - - # Add UNIQUE constraint if unique_columns exist - if unique_columns: - quoted_uq_cols = ", ".join(f'"{col}"' for col in unique_columns) - statements.append( - f"ALTER TABLE {qualified_name} " - f"ADD CONSTRAINT {uq_constraint_name} " - f"UNIQUE ({quoted_uq_cols})" - ) - - return statements + def _get_constraints_sql( + self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool + ) -> str: + # "primary_key": "PRIMARY KEY" + if self.config.create_indexes: + partial: TTableSchema = { + "name": table_name, + "columns": {c["name"]: c for c in new_columns}, + } + # Add PK constraint if pk_columns exist + pk_columns = get_columns_names_with_prop(partial, "primary_key") + if pk_columns: + if generate_alter: + logger.warning( + f"PRIMARY KEY on {table_name} constraint cannot be added in ALTER TABLE and" + " is ignored" + ) + else: + pk_constraint_name = list( + self._norm_and_escape_columns(f"PK_{table_name}_{uniq_id(4)}") + )[0] + quoted_pk_cols = ", ".join( + self.sql_client.escape_column_name(col) for col in pk_columns + ) + return f",\nCONSTRAINT {pk_constraint_name} PRIMARY KEY ({quoted_pk_cols})" + return "" def _get_table_update_sql( self, @@ -338,13 +313,6 @@ def _get_table_update_sql( if cluster_list: sql[0] = sql[0] + "\nCLUSTER BY (" + ",".join(cluster_list) + ")" - if self.active_hints: - existing_constraints = self._get_existing_constraints(table_name) - statements = self._get_constraints_statement( - table_name, new_columns, existing_constraints - ) - sql.extend(statements) - return sql def _from_db_type( @@ -352,11 +320,5 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_destination_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - name = self.sql_client.escape_column_name(c["name"]) - return ( - f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" - ) - def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: return 
self.config.truncate_tables_on_staging_destination_before_load diff --git a/tests/load/snowflake/test_snowflake_client.py b/tests/load/snowflake/test_snowflake_client.py index aebf514b56..674e01ba31 100644 --- a/tests/load/snowflake/test_snowflake_client.py +++ b/tests/load/snowflake/test_snowflake_client.py @@ -1,14 +1,17 @@ +from copy import deepcopy import os from typing import Iterator from pytest_mock import MockerFixture import pytest -from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient +from dlt.common.schema.schema import Schema +from dlt.destinations.impl.snowflake.snowflake import SUPPORTED_HINTS, SnowflakeClient from dlt.destinations.job_client_impl import SqlJobClientBase from dlt.destinations.sql_client import TJobQueryTags -from tests.load.utils import yield_client_with_storage +from tests.cases import TABLE_UPDATE +from tests.load.utils import yield_client_with_storage, empty_schema # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -32,6 +35,39 @@ def client() -> Iterator[SqlJobClientBase]: yield from yield_client_with_storage("snowflake") +def test_create_table_with_hints(client: SnowflakeClient, empty_schema: Schema) -> None: + mod_update = deepcopy(TABLE_UPDATE[:11]) + # mock hints + client.config.create_indexes = True + client.active_hints = SUPPORTED_HINTS + client.schema = empty_schema + + mod_update[0]["primary_key"] = True + mod_update[5]["primary_key"] = True + + mod_update[0]["sort"] = True + mod_update[4]["parent_key"] = True + + # unique constraints are always single columns + mod_update[1]["unique"] = True + mod_update[7]["unique"] = True + + sql = ";".join(client._get_table_update_sql("event_test_table", mod_update, False)) + + print(sql) + client.sql_client.execute_sql(sql) + + # generate alter table + mod_update = deepcopy(TABLE_UPDATE[11:]) + mod_update[0]["primary_key"] = True + mod_update[1]["unique"] = True + + sql = ";".join(client._get_table_update_sql("event_test_table", mod_update, True)) + + print(sql) + client.sql_client.execute_sql(sql) + + def test_query_tag(client: SnowflakeClient, mocker: MockerFixture): assert client.config.query_tag == QUERY_TAG # make sure we generate proper query diff --git a/tests/load/snowflake/test_snowflake_table_builder.py b/tests/load/snowflake/test_snowflake_table_builder.py index e2bd27d18a..43d4395188 100644 --- a/tests/load/snowflake/test_snowflake_table_builder.py +++ b/tests/load/snowflake/test_snowflake_table_builder.py @@ -6,7 +6,7 @@ from dlt.common.utils import uniq_id from dlt.common.schema import Schema, utils from dlt.destinations import snowflake -from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient +from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient, SUPPORTED_HINTS from dlt.destinations.impl.snowflake.configuration import ( SnowflakeClientConfiguration, SnowflakeCredentials, @@ -66,59 +66,63 @@ def test_create_table(snowflake_client: SnowflakeClient) -> None: assert sql.strip().startswith("CREATE TABLE") assert "EVENT_TEST_TABLE" in sql - assert '"COL1" NUMBER(19,0) NOT NULL' in sql - assert '"COL2" FLOAT NOT NULL' in sql - assert '"COL3" BOOLEAN NOT NULL' in sql - assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql assert '"COL5" VARCHAR' in sql - assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' in sql assert 
'"COL7" BINARY' in sql assert '"COL8" NUMBER(38,0)' in sql - assert '"COL9" VARIANT NOT NULL' in sql - assert '"COL10" DATE NOT NULL' in sql + assert '"COL9" VARIANT NOT NULL' in sql + assert '"COL10" DATE NOT NULL' in sql def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None: - mod_update = deepcopy(TABLE_UPDATE) + mod_update = deepcopy(TABLE_UPDATE[:11]) + # mock hints + snowflake_client.config.create_indexes = True + snowflake_client.active_hints = SUPPORTED_HINTS mod_update[0]["primary_key"] = True + mod_update[5]["primary_key"] = True + mod_update[0]["sort"] = True + + # unique constraints are always single columns mod_update[1]["unique"] = True + mod_update[7]["unique"] = True + mod_update[4]["parent_key"] = True sql = ";".join(snowflake_client._get_table_update_sql("event_test_table", mod_update, False)) assert sql.strip().startswith("CREATE TABLE") assert "EVENT_TEST_TABLE" in sql - assert '"COL1" NUMBER(19,0) NOT NULL' in sql - assert '"COL2" FLOAT NOT NULL' in sql - assert '"COL3" BOOLEAN NOT NULL' in sql - assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT UNIQUE NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql assert '"COL5" VARCHAR' in sql - assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' in sql assert '"COL7" BINARY' in sql - assert '"COL8" NUMBER(38,0)' in sql - assert '"COL9" VARIANT NOT NULL' in sql - assert '"COL10" DATE NOT NULL' in sql + assert '"COL8" NUMBER(38,0) UNIQUE' in sql + assert '"COL9" VARIANT NOT NULL' in sql + assert '"COL10" DATE NOT NULL' in sql - # same thing with indexes - snowflake_client = snowflake().client( - snowflake_client.schema, - SnowflakeClientConfiguration(create_indexes=True)._bind_dataset_name( - dataset_name="test_" + uniq_id() - ), - ) - sql_statements = snowflake_client._get_table_update_sql("event_test_table", mod_update, False) + # PRIMARY KEY constraint + assert 'CONSTRAINT "PK_EVENT_TEST_TABLE_' in sql + assert 'PRIMARY KEY ("COL1", "COL6")' in sql - for stmt in sql_statements: - sqlfluff.parse(stmt) + # generate alter + mod_update = deepcopy(TABLE_UPDATE[11:]) + mod_update[0]["primary_key"] = True + mod_update[1]["unique"] = True - assert any( - 'ADD CONSTRAINT PK_EVENT_TEST_TABLE PRIMARY KEY ("col1")' in stmt for stmt in sql_statements - ) - assert any( - 'ADD CONSTRAINT UQ_EVENT_TEST_TABLE UNIQUE ("col2")' in stmt for stmt in sql_statements - ) + sql = ";".join(snowflake_client._get_table_update_sql("event_test_table", mod_update, True)) + # PK constraint ignored for alter + assert "PRIMARY KEY" not in sql + assert '"COL2_NULL" FLOAT UNIQUE' in sql def test_alter_table(snowflake_client: SnowflakeClient) -> None: @@ -133,15 +137,15 @@ def test_alter_table(snowflake_client: SnowflakeClient) -> None: assert sql.count("ALTER TABLE") == 1 assert sql.count("ADD COLUMN") == 1 assert '"EVENT_TEST_TABLE"' in sql - assert '"COL1" NUMBER(19,0) NOT NULL' in sql - assert '"COL2" FLOAT NOT NULL' in sql - assert '"COL3" BOOLEAN NOT NULL' in sql - assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql assert '"COL5" VARCHAR' in sql - assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' in sql assert '"COL7" BINARY' in sql assert '"COL8" NUMBER(38,0)' in sql - assert '"COL9" 
VARIANT NOT NULL' in sql + assert '"COL9" VARIANT NOT NULL' in sql assert '"COL10" DATE' in sql mod_table = deepcopy(TABLE_UPDATE) @@ -149,7 +153,7 @@ def test_alter_table(snowflake_client: SnowflakeClient) -> None: sql = snowflake_client._get_table_update_sql("event_test_table", mod_table, True)[0] assert '"COL1"' not in sql - assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql def test_create_table_case_sensitive(cs_client: SnowflakeClient) -> None: From f2bac9642caafc56ba6ff834326086d306428ebb Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 17 Dec 2024 00:23:37 +0100 Subject: [PATCH 6/8] refactors column and constraint sql in job client --- .../impl/databricks/databricks.py | 6 ---- dlt/destinations/impl/dremio/dremio.py | 6 ---- dlt/destinations/impl/duckdb/duck.py | 11 ------ dlt/destinations/impl/mssql/mssql.py | 6 +--- dlt/destinations/impl/postgres/postgres.py | 12 ------- dlt/destinations/impl/redshift/redshift.py | 12 +------ dlt/destinations/job_client_impl.py | 36 ++++++++++++++----- 7 files changed, 29 insertions(+), 60 deletions(-) diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 2bb68a607e..a83db6ec34 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -264,12 +264,6 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_destination_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - name = self.sql_client.escape_column_name(c["name"]) - return ( - f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" - ) - def _get_storage_table_query_columns(self) -> List[str]: fields = super()._get_storage_table_query_columns() fields[2] = ( # Override because this is the only way to get data type with precision diff --git a/dlt/destinations/impl/dremio/dremio.py b/dlt/destinations/impl/dremio/dremio.py index ab23f58ab4..e3a090c824 100644 --- a/dlt/destinations/impl/dremio/dremio.py +++ b/dlt/destinations/impl/dremio/dremio.py @@ -151,12 +151,6 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_destination_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - name = self.sql_client.escape_column_name(c["name"]) - return ( - f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" - ) - def _create_merge_followup_jobs( self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: diff --git a/dlt/destinations/impl/duckdb/duck.py b/dlt/destinations/impl/duckdb/duck.py index 3bd4c83e1f..2b3370270b 100644 --- a/dlt/destinations/impl/duckdb/duck.py +++ b/dlt/destinations/impl/duckdb/duck.py @@ -74,17 +74,6 @@ def create_load_job( job = DuckDbCopyJob(file_path) return job - def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - hints_str = " ".join( - self.active_hints.get(h, "") - for h in self.active_hints.keys() - if c.get(h, False) is True - ) - column_name = self.sql_client.escape_column_name(c["name"]) - return ( - f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" - ) - def _from_db_type( self, pq_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: diff --git a/dlt/destinations/impl/mssql/mssql.py 
b/dlt/destinations/impl/mssql/mssql.py index 27aebe07f2..7b48a6b551 100644 --- a/dlt/destinations/impl/mssql/mssql.py +++ b/dlt/destinations/impl/mssql/mssql.py @@ -115,11 +115,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = Non else: db_type = self.type_mapper.to_destination_type(c, table) - hints_str = " ".join( - self.active_hints.get(h, "") - for h in self.active_hints.keys() - if c.get(h, False) is True - ) + hints_str = self._get_column_hints_sql(c) column_name = self.sql_client.escape_column_name(c["name"]) return f"{column_name} {db_type} {hints_str} {self._gen_not_null(c.get('nullable', True))}" diff --git a/dlt/destinations/impl/postgres/postgres.py b/dlt/destinations/impl/postgres/postgres.py index 2459ee1dbe..3d54b59f93 100644 --- a/dlt/destinations/impl/postgres/postgres.py +++ b/dlt/destinations/impl/postgres/postgres.py @@ -161,18 +161,6 @@ def create_load_job( job = PostgresCsvCopyJob(file_path) return job - def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - hints_ = " ".join( - self.active_hints.get(h, "") - for h in self.active_hints.keys() - if c.get(h, False) is True - ) - column_name = self.sql_client.escape_column_name(c["name"]) - nullability = self._gen_not_null(c.get("nullable", True)) - column_type = self.type_mapper.to_destination_type(c, table) - - return f"{column_name} {column_type} {hints_} {nullability}" - def _create_replace_followup_jobs( self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: diff --git a/dlt/destinations/impl/redshift/redshift.py b/dlt/destinations/impl/redshift/redshift.py index 2335166761..b1aa37ce6a 100644 --- a/dlt/destinations/impl/redshift/redshift.py +++ b/dlt/destinations/impl/redshift/redshift.py @@ -153,6 +153,7 @@ def __init__( capabilities, ) super().__init__(schema, config, sql_client) + self.active_hints = HINT_TO_REDSHIFT_ATTR self.sql_client = sql_client self.config: RedshiftClientConfiguration = config self.type_mapper = self.capabilities.get_type_mapper() @@ -162,17 +163,6 @@ def _create_merge_followup_jobs( ) -> List[FollowupJobRequest]: return [RedshiftMergeJob.from_table_chain(table_chain, self.sql_client)] - def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - hints_str = " ".join( - HINT_TO_REDSHIFT_ATTR.get(h, "") - for h in HINT_TO_REDSHIFT_ATTR.keys() - if c.get(h, False) is True - ) - column_name = self.sql_client.escape_column_name(c["name"]) - return ( - f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" - ) - def create_load_job( self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index d1f211b1e9..12cb129812 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -7,6 +7,7 @@ from typing import ( Any, ClassVar, + Dict, List, Optional, Sequence, @@ -14,21 +15,18 @@ Type, Iterable, Iterator, - Generator, ) import zlib import re -from contextlib import contextmanager -from contextlib import suppress from dlt.common import pendulum, logger +from dlt.common.destination.capabilities import DataTypeMapper from dlt.common.json import json from dlt.common.schema.typing import ( C_DLT_LOAD_ID, COLUMN_HINTS, TColumnType, TColumnSchemaBase, - TTableFormat, ) from dlt.common.schema.utils import ( get_inherited_table_hint, @@ -40,11 +38,11 @@ 
from dlt.common.storages import FileStorage from dlt.common.storages.load_package import LoadJobInfo, ParsedLoadJobFileName from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables +from dlt.common.schema import TColumnHint from dlt.common.destination.reference import ( PreparedTableSchema, StateInfo, StorageSchemaInfo, - SupportsReadableDataset, WithStateSync, DestinationClientConfiguration, DestinationClientDwhConfiguration, @@ -55,9 +53,7 @@ JobClientBase, HasFollowupJobs, CredentialsConfiguration, - SupportsReadableRelation, ) -from dlt.destinations.dataset import ReadableDBAPIDataset from dlt.destinations.exceptions import DatabaseUndefinedRelation from dlt.destinations.job_impl import ( @@ -154,6 +150,8 @@ def __init__( self.state_table_columns = ", ".join( sql_client.escape_column_name(col) for col in state_table_["columns"] ) + self.active_hints: Dict[TColumnHint, str] = {} + self.type_mapper: DataTypeMapper = None super().__init__(schema, config, sql_client.capabilities) self.sql_client = sql_client assert isinstance(config, DestinationClientDwhConfiguration) @@ -569,6 +567,7 @@ def _get_table_update_sql( # build CREATE sql = self._make_create_table(qualified_name, table) + " (\n" sql += ",\n".join([self._get_column_def_sql(c, table) for c in new_columns]) + sql += self._get_constraints_sql(table_name, new_columns, generate_alter) sql += ")" sql_result.append(sql) else: @@ -582,8 +581,16 @@ def _get_table_update_sql( sql_result.extend( [sql_base + col_statement for col_statement in add_column_statements] ) + constraints_sql = self._get_constraints_sql(table_name, new_columns, generate_alter) + if constraints_sql: + sql_result.append(constraints_sql) return sql_result + def _get_constraints_sql( + self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool + ) -> str: + return "" + def _check_table_update_hints( self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool ) -> None: @@ -613,9 +620,20 @@ def _check_table_update_hints( " existing tables." 
) - @abstractmethod def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - pass + hints_ = self._get_column_hints_sql(c) + column_name = self.sql_client.escape_column_name(c["name"]) + nullability = self._gen_not_null(c.get("nullable", True)) + column_type = self.type_mapper.to_destination_type(c, table) + + return f"{column_name} {column_type} {hints_} {nullability}" + + def _get_column_hints_sql(self, c: TColumnSchema) -> str: + return " ".join( + self.active_hints.get(h, "") + for h in self.active_hints.keys() + if c.get(h, False) is True # use ColumnPropInfos to get default value + ) @staticmethod def _gen_not_null(nullable: bool) -> str: From 95069389f1f709d6d9183e959fc7deba7517e670 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 17 Dec 2024 10:51:14 +0100 Subject: [PATCH 7/8] fixes tests, adds docs --- dlt/destinations/impl/bigquery/bigquery.py | 5 +---- dlt/destinations/impl/clickhouse/clickhouse.py | 5 ++--- dlt/destinations/job_client_impl.py | 3 +-- .../website/docs/dlt-ecosystem/destinations/snowflake.md | 9 +++++++++ tests/load/dremio/test_dremio_client.py | 8 ++++---- 5 files changed, 17 insertions(+), 13 deletions(-) diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index 2b3927e7c9..10a344f768 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -401,10 +401,7 @@ def _get_info_schema_columns_query( return query, folded_table_names def _get_column_def_sql(self, column: TColumnSchema, table: PreparedTableSchema = None) -> str: - name = self.sql_client.escape_column_name(column["name"]) - column_def_sql = ( - f"{name} {self.type_mapper.to_destination_type(column, table)} {self._gen_not_null(column.get('nullable', True))}" - ) + column_def_sql = super()._get_column_def_sql(column, table) if column.get(ROUND_HALF_EVEN_HINT, False): column_def_sql += " OPTIONS (rounding_mode='ROUND_HALF_EVEN')" if column.get(ROUND_HALF_AWAY_FROM_ZERO_HINT, False): diff --git a/dlt/destinations/impl/clickhouse/clickhouse.py b/dlt/destinations/impl/clickhouse/clickhouse.py index 3a5f5c3e28..a407e56361 100644 --- a/dlt/destinations/impl/clickhouse/clickhouse.py +++ b/dlt/destinations/impl/clickhouse/clickhouse.py @@ -292,11 +292,10 @@ def _get_table_update_sql( return sql - @staticmethod - def _gen_not_null(v: bool) -> str: + def _gen_not_null(self, v: bool) -> str: # ClickHouse fields are not nullable by default. # We use the `Nullable` modifier instead of NULL / NOT NULL modifiers to cater for ALTER statement. 
- pass + return "" def _from_db_type( self, ch_t: str, precision: Optional[int], scale: Optional[int] diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 12cb129812..888c80c006 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -635,8 +635,7 @@ def _get_column_hints_sql(self, c: TColumnSchema) -> str: if c.get(h, False) is True # use ColumnPropInfos to get default value ) - @staticmethod - def _gen_not_null(nullable: bool) -> str: + def _gen_not_null(self, nullable: bool) -> str: return "NOT NULL" if not nullable else "" def _create_table_update( diff --git a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md index 07cf822973..28684c39ac 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md +++ b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md @@ -200,6 +200,12 @@ Note that we ignore missing columns `ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE` and ## Supported column hints Snowflake supports the following [column hints](../../general-usage/schema#tables-and-columns): * `cluster` - Creates a cluster column(s). Many columns per table are supported and only when a new table is created. +* `unique` - Creates UNIQUE hint on a Snowflake column, can be added to many columns. ([optional](#additional-destination-options)) +* `primary_key` - Creates PRIMARY KEY on selected column(s), may be compound. ([optional](#additional-destination-options)) + +`unique` and `primary_key` are not enforced and `dlt` does not instruct Snowflake to `RELY` on them when +query planning. + ## Table and column identifiers Snowflake supports both case-sensitive and case-insensitive identifiers. All unquoted and uppercase identifiers resolve case-insensitively in SQL statements. Case-insensitive [naming conventions](../../general-usage/naming-convention.md#case-sensitive-and-insensitive-destinations) like the default **snake_case** will generate case-insensitive identifiers. Case-sensitive (like **sql_cs_v1**) will generate @@ -308,6 +314,7 @@ pipeline = dlt.pipeline( ## Additional destination options You can define your own stage to PUT files and disable the removal of the staged files after loading. +You can also opt-in to [create indexes](#supported-column-hints). 
```toml [destination.snowflake] @@ -315,6 +322,8 @@ You can define your own stage to PUT files and disable the removal of the staged stage_name="DLT_STAGE" # Whether to keep or delete the staged files after COPY INTO succeeds keep_staged_files=true +# Add UNIQUE and PRIMARY KEY hints to tables +create_indexes=true ``` ### Setting up CSV format diff --git a/tests/load/dremio/test_dremio_client.py b/tests/load/dremio/test_dremio_client.py index efc72c0652..98212efb13 100644 --- a/tests/load/dremio/test_dremio_client.py +++ b/tests/load/dremio/test_dremio_client.py @@ -48,12 +48,12 @@ def test_dremio_factory() -> None: [ TColumnSchema(name="foo", data_type="text", partition=True), TColumnSchema(name="bar", data_type="bigint", sort=True), - TColumnSchema(name="baz", data_type="double"), + TColumnSchema(name="baz", data_type="double", nullable=False), ], False, [ 'CREATE TABLE "test_database"."test_dataset"."event_test_table"' - ' (\n"foo" VARCHAR ,\n"bar" BIGINT ,\n"baz" DOUBLE )\nPARTITION BY' + ' (\n"foo" VARCHAR ,\n"bar" BIGINT ,\n"baz" DOUBLE NOT NULL)\nPARTITION BY' ' ("foo")\nLOCALSORT BY ("bar")' ], ), @@ -66,7 +66,7 @@ def test_dremio_factory() -> None: False, [ 'CREATE TABLE "test_database"."test_dataset"."event_test_table"' - ' (\n"foo" VARCHAR ,\n"bar" BIGINT ,\n"baz" DOUBLE )\nPARTITION BY' + ' (\n"foo" VARCHAR ,\n"bar" BIGINT ,\n"baz" DOUBLE )\nPARTITION BY' ' ("foo","bar")' ], ), @@ -79,7 +79,7 @@ def test_dremio_factory() -> None: False, [ 'CREATE TABLE "test_database"."test_dataset"."event_test_table"' - ' (\n"foo" VARCHAR ,\n"bar" BIGINT ,\n"baz" DOUBLE )' + ' (\n"foo" VARCHAR ,\n"bar" BIGINT ,\n"baz" DOUBLE )' ], ), ], From fb79d80eefac1c92a34cca267f3c6b58a87264c3 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 17 Dec 2024 15:07:22 +0100 Subject: [PATCH 8/8] fixes bigquery table builder tests --- .../bigquery/test_bigquery_table_builder.py | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/tests/load/bigquery/test_bigquery_table_builder.py b/tests/load/bigquery/test_bigquery_table_builder.py index 56a674cfa3..b2857b7c08 100644 --- a/tests/load/bigquery/test_bigquery_table_builder.py +++ b/tests/load/bigquery/test_bigquery_table_builder.py @@ -107,25 +107,25 @@ def test_create_table(gcp_client: BigQueryClient) -> None: sqlfluff.parse(sql, dialect="bigquery") assert sql.startswith("CREATE TABLE") assert "event_test_table" in sql - assert "`col1` INT64 NOT NULL" in sql - assert "`col2` FLOAT64 NOT NULL" in sql - assert "`col3` BOOL NOT NULL" in sql - assert "`col4` TIMESTAMP NOT NULL" in sql + assert "`col1` INT64 NOT NULL" in sql + assert "`col2` FLOAT64 NOT NULL" in sql + assert "`col3` BOOL NOT NULL" in sql + assert "`col4` TIMESTAMP NOT NULL" in sql assert "`col5` STRING " in sql - assert "`col6` NUMERIC(38,9) NOT NULL" in sql + assert "`col6` NUMERIC(38,9) NOT NULL" in sql assert "`col7` BYTES" in sql assert "`col8` BIGNUMERIC" in sql - assert "`col9` JSON NOT NULL" in sql + assert "`col9` JSON NOT NULL" in sql assert "`col10` DATE" in sql assert "`col11` TIME" in sql - assert "`col1_precision` INT64 NOT NULL" in sql - assert "`col4_precision` TIMESTAMP NOT NULL" in sql + assert "`col1_precision` INT64 NOT NULL" in sql + assert "`col4_precision` TIMESTAMP NOT NULL" in sql assert "`col5_precision` STRING(25) " in sql - assert "`col6_precision` NUMERIC(6,2) NOT NULL" in sql + assert "`col6_precision` NUMERIC(6,2) NOT NULL" in sql assert "`col7_precision` BYTES(19)" in sql - assert "`col11_precision` TIME NOT NULL" in sql - assert 
"`col_high_p_decimal` BIGNUMERIC(76,0) NOT NULL" in sql - assert "`col_high_s_decimal` BIGNUMERIC(38,24) NOT NULL" in sql + assert "`col11_precision` TIME NOT NULL" in sql + assert "`col_high_p_decimal` BIGNUMERIC(76,0) NOT NULL" in sql + assert "`col_high_s_decimal` BIGNUMERIC(38,24) NOT NULL" in sql assert "CLUSTER BY" not in sql assert "PARTITION BY" not in sql @@ -137,29 +137,29 @@ def test_alter_table(gcp_client: BigQueryClient) -> None: assert sql.startswith("ALTER TABLE") assert sql.count("ALTER TABLE") == 1 assert "event_test_table" in sql - assert "ADD COLUMN `col1` INT64 NOT NULL" in sql - assert "ADD COLUMN `col2` FLOAT64 NOT NULL" in sql - assert "ADD COLUMN `col3` BOOL NOT NULL" in sql - assert "ADD COLUMN `col4` TIMESTAMP NOT NULL" in sql + assert "ADD COLUMN `col1` INT64 NOT NULL" in sql + assert "ADD COLUMN `col2` FLOAT64 NOT NULL" in sql + assert "ADD COLUMN `col3` BOOL NOT NULL" in sql + assert "ADD COLUMN `col4` TIMESTAMP NOT NULL" in sql assert "ADD COLUMN `col5` STRING" in sql - assert "ADD COLUMN `col6` NUMERIC(38,9) NOT NULL" in sql + assert "ADD COLUMN `col6` NUMERIC(38,9) NOT NULL" in sql assert "ADD COLUMN `col7` BYTES" in sql assert "ADD COLUMN `col8` BIGNUMERIC" in sql - assert "ADD COLUMN `col9` JSON NOT NULL" in sql + assert "ADD COLUMN `col9` JSON NOT NULL" in sql assert "ADD COLUMN `col10` DATE" in sql assert "ADD COLUMN `col11` TIME" in sql - assert "ADD COLUMN `col1_precision` INT64 NOT NULL" in sql - assert "ADD COLUMN `col4_precision` TIMESTAMP NOT NULL" in sql + assert "ADD COLUMN `col1_precision` INT64 NOT NULL" in sql + assert "ADD COLUMN `col4_precision` TIMESTAMP NOT NULL" in sql assert "ADD COLUMN `col5_precision` STRING(25)" in sql - assert "ADD COLUMN `col6_precision` NUMERIC(6,2) NOT NULL" in sql + assert "ADD COLUMN `col6_precision` NUMERIC(6,2) NOT NULL" in sql assert "ADD COLUMN `col7_precision` BYTES(19)" in sql - assert "ADD COLUMN `col11_precision` TIME NOT NULL" in sql + assert "ADD COLUMN `col11_precision` TIME NOT NULL" in sql # table has col1 already in storage mod_table = deepcopy(TABLE_UPDATE) mod_table.pop(0) sql = gcp_client._get_table_update_sql("event_test_table", mod_table, True)[0] - assert "ADD COLUMN `col1` INTEGER NOT NULL" not in sql - assert "ADD COLUMN `col2` FLOAT64 NOT NULL" in sql + assert "ADD COLUMN `col1` INTEGER NOT NULL" not in sql + assert "ADD COLUMN `col2` FLOAT64 NOT NULL" in sql def test_create_table_case_insensitive(ci_gcp_client: BigQueryClient) -> None: