emits PKs for snowflake, CREATE only
rudolfix committed Dec 16, 2024
1 parent e912c51 · commit a6fcb85
Showing 3 changed files with 113 additions and 111 deletions.
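In short: the commit removes the ALTER-TABLE-based constraint management (`_get_existing_constraints` / `_get_constraints_statement`), drops `primary_key` from SUPPORTED_HINTS so it no longer renders as an inline column hint, and instead emits a named PRIMARY KEY constraint inside the CREATE TABLE column list; on ALTER TABLE the primary key is skipped with a warning. A minimal standalone sketch of that logic — `constraints_clause` and the uuid-based suffix are stand-ins for dlt's `_get_constraints_sql` and `uniq_id(4)`:

from typing import Any, Dict, List
from uuid import uuid4

def constraints_clause(table_name: str, new_columns: List[Dict[str, Any]], generate_alter: bool) -> str:
    # render the clause appended after the last column of a CREATE TABLE statement
    pk_columns = [c["name"] for c in new_columns if c.get("primary_key")]
    if not pk_columns:
        return ""
    if generate_alter:
        # PKs are emitted for CREATE only; the ALTER TABLE path skips them with a warning
        print(f"warning: PRIMARY KEY on {table_name} cannot be added in ALTER TABLE and is ignored")
        return ""
    constraint_name = f"PK_{table_name}_{uuid4().hex[:4]}".upper()  # stand-in for uniq_id(4)
    quoted_cols = ", ".join(f'"{name.upper()}"' for name in pk_columns)
    return f",\nCONSTRAINT {constraint_name} PRIMARY KEY ({quoted_cols})"

# prints e.g. ',\nCONSTRAINT PK_EVENT_TEST_TABLE_A1B2 PRIMARY KEY ("COL1", "COL6")'
print(constraints_clause("event_test_table", [{"name": "col1", "primary_key": True}, {"name": "col6", "primary_key": True}], False))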
dlt/destinations/impl/snowflake/snowflake.py (31 additions, 69 deletions)

@@ -1,6 +1,7 @@
from typing import Optional, Sequence, List, Dict, Set
from urllib.parse import urlparse, urlunparse

from dlt.common import logger
from dlt.common.data_writers.configuration import CsvFormatConfiguration
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import (
@@ -15,21 +16,23 @@
    AwsCredentialsWithoutDefaults,
    AzureCredentialsWithoutDefaults,
)
from dlt.common.schema.utils import get_columns_names_with_prop
from dlt.common.storages.configuration import FilesystemConfiguration, ensure_canonical_az_url
from dlt.common.storages.file_storage import FileStorage
from dlt.common.schema import TColumnSchema, Schema, TColumnHint
from dlt.common.schema.typing import TColumnType
from dlt.common.schema.typing import TColumnType, TTableSchema

from dlt.common.storages.fsspec_filesystem import AZURE_BLOB_STORAGE_PROTOCOLS, S3_PROTOCOLS
from dlt.common.typing import TLoaderFileFormat
from dlt.common.utils import uniq_id
from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset
from dlt.destinations.exceptions import LoadJobTerminalException

from dlt.destinations.impl.snowflake.configuration import SnowflakeClientConfiguration
from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient
from dlt.destinations.job_impl import ReferenceFollowupJobRequest

SUPPORTED_HINTS: Dict[TColumnHint, str] = {"unique": "UNIQUE", "primary_key": "PRIMARY KEY"}
SUPPORTED_HINTS: Dict[TColumnHint, str] = {"unique": "UNIQUE"}


class SnowflakeLoadJob(RunnableLoadJob, HasFollowupJobs):
@@ -267,60 +270,32 @@ def _make_add_column_sql(
"ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns)
]

    def _get_existing_constraints(self, table_name: str) -> Set[str]:
        query = f"""
            SELECT constraint_name
            FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS
            WHERE TABLE_NAME = '{table_name.upper()}'
        """

        if self.sql_client.catalog_name:
            query += f" AND CONSTRAINT_CATALOG = '{self.sql_client.catalog_name}'"

        with self.sql_client.open_connection() as conn:
            cursors = conn.execute_string(query)
            existing_names = set()
            for cursor in cursors:
                for row in cursor:
                    existing_names.add(row[0])
        return existing_names

    def _get_constraints_statement(
        self, table_name: str, columns: Sequence[TColumnSchema], existing_constraints: Set[str]
    ) -> List[str]:
        statements = []
        pk_constraint_name = f"PK_{table_name.upper()}"
        uq_constraint_name = f"UQ_{table_name.upper()}"
        qualified_name = self.sql_client.make_qualified_table_name(table_name)

        pk_columns = [col["name"] for col in columns if col.get("primary_key")]
        unique_columns = [col["name"] for col in columns if col.get("unique")]

        # Drop existing PK/UQ constraints if found
        if pk_constraint_name in existing_constraints:
            statements.append(f"ALTER TABLE {qualified_name} DROP CONSTRAINT {pk_constraint_name}")
        if uq_constraint_name in existing_constraints:
            statements.append(f"ALTER TABLE {qualified_name} DROP CONSTRAINT {uq_constraint_name}")

        # Add PK constraint if pk_columns exist
        if pk_columns:
            quoted_pk_cols = ", ".join(f'"{col}"' for col in pk_columns)
            statements.append(
                f"ALTER TABLE {qualified_name} "
                f"ADD CONSTRAINT {pk_constraint_name} "
                f"PRIMARY KEY ({quoted_pk_cols})"
            )

        # Add UNIQUE constraint if unique_columns exist
        if unique_columns:
            quoted_uq_cols = ", ".join(f'"{col}"' for col in unique_columns)
            statements.append(
                f"ALTER TABLE {qualified_name} "
                f"ADD CONSTRAINT {uq_constraint_name} "
                f"UNIQUE ({quoted_uq_cols})"
            )

        return statements
    def _get_constraints_sql(
        self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
    ) -> str:
        # "primary_key": "PRIMARY KEY"
        if self.config.create_indexes:
            partial: TTableSchema = {
                "name": table_name,
                "columns": {c["name"]: c for c in new_columns},
            }
            # Add PK constraint if pk_columns exist
            pk_columns = get_columns_names_with_prop(partial, "primary_key")
            if pk_columns:
                if generate_alter:
                    logger.warning(
                        f"PRIMARY KEY on {table_name} constraint cannot be added in ALTER TABLE and"
                        " is ignored"
                    )
                else:
                    pk_constraint_name = list(
                        self._norm_and_escape_columns(f"PK_{table_name}_{uniq_id(4)}")
                    )[0]
                    quoted_pk_cols = ", ".join(
                        self.sql_client.escape_column_name(col) for col in pk_columns
                    )
                    return f",\nCONSTRAINT {pk_constraint_name} PRIMARY KEY ({quoted_pk_cols})"
        return ""

    def _get_table_update_sql(
        self,
@@ -338,25 +313,12 @@ def _get_table_update_sql
        if cluster_list:
            sql[0] = sql[0] + "\nCLUSTER BY (" + ",".join(cluster_list) + ")"

        if self.active_hints:
            existing_constraints = self._get_existing_constraints(table_name)
            statements = self._get_constraints_statement(
                table_name, new_columns, existing_constraints
            )
            sql.extend(statements)

        return sql

    def _from_db_type(
        self, bq_t: str, precision: Optional[int], scale: Optional[int]
    ) -> TColumnType:
        return self.type_mapper.from_destination_type(bq_t, precision, scale)

    def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
        name = self.sql_client.escape_column_name(c["name"])
        return (
            f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}"
        )

    def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool:
        return self.config.truncate_tables_on_staging_destination_before_load
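Put together with the inline `unique` hint, the CREATE TABLE emitted when `create_indexes` is enabled should take roughly this shape — a sketch inferred from the test assertions below; the four-character constraint-name suffix comes from `uniq_id(4)` and varies per run:

# rough shape of the emitted DDL, inferred from the assertions in the tests below;
# "XY12" stands in for the random uniq_id(4) suffix
EXPECTED_SHAPE = """
CREATE TABLE "DATASET"."EVENT_TEST_TABLE" (
    "COL1" NUMBER(19,0) NOT NULL,
    "COL2" FLOAT UNIQUE NOT NULL,
    "COL6" NUMBER(38,9) NOT NULL,
    CONSTRAINT "PK_EVENT_TEST_TABLE_XY12" PRIMARY KEY ("COL1", "COL6")
)
"""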
tests/load/snowflake/test_snowflake_client.py (38 additions, 2 deletions)

@@ -1,14 +1,17 @@
from copy import deepcopy
import os
from typing import Iterator
from pytest_mock import MockerFixture
import pytest

from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient
from dlt.common.schema.schema import Schema
from dlt.destinations.impl.snowflake.snowflake import SUPPORTED_HINTS, SnowflakeClient
from dlt.destinations.job_client_impl import SqlJobClientBase

from dlt.destinations.sql_client import TJobQueryTags

from tests.load.utils import yield_client_with_storage
from tests.cases import TABLE_UPDATE
from tests.load.utils import yield_client_with_storage, empty_schema

# mark all tests as essential, do not remove
pytestmark = pytest.mark.essential
@@ -32,6 +35,39 @@ def client() -> Iterator[SqlJobClientBase]:
    yield from yield_client_with_storage("snowflake")


def test_create_table_with_hints(client: SnowflakeClient, empty_schema: Schema) -> None:
    mod_update = deepcopy(TABLE_UPDATE[:11])
    # mock hints
    client.config.create_indexes = True
    client.active_hints = SUPPORTED_HINTS
    client.schema = empty_schema

    mod_update[0]["primary_key"] = True
    mod_update[5]["primary_key"] = True

    mod_update[0]["sort"] = True
    mod_update[4]["parent_key"] = True

    # unique constraints are always single columns
    mod_update[1]["unique"] = True
    mod_update[7]["unique"] = True

    sql = ";".join(client._get_table_update_sql("event_test_table", mod_update, False))

    print(sql)
    client.sql_client.execute_sql(sql)

    # generate alter table
    mod_update = deepcopy(TABLE_UPDATE[11:])
    mod_update[0]["primary_key"] = True
    mod_update[1]["unique"] = True

    sql = ";".join(client._get_table_update_sql("event_test_table", mod_update, True))

    print(sql)
    client.sql_client.execute_sql(sql)


def test_query_tag(client: SnowflakeClient, mocker: MockerFixture):
    assert client.config.query_tag == QUERY_TAG
    # make sure we generate proper query
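For context, the behavior these integration tests exercise would be driven from a pipeline roughly like the sketch below. This is hypothetical wiring: it assumes `create_indexes` can be passed straight to the `snowflake` destination factory, and that credentials come from secrets.toml in a real setup.

import dlt

@dlt.resource(primary_key=("user_id", "day"), write_disposition="merge")
def events():
    yield {"user_id": 1, "day": "2024-12-16", "clicks": 3}

pipeline = dlt.pipeline(
    pipeline_name="snowflake_pk_demo",
    # hypothetical: create_indexes switches on PK/UNIQUE emission for Snowflake
    destination=dlt.destinations.snowflake(create_indexes=True),
    dataset_name="pk_demo",
)
pipeline.run(events())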
tests/load/snowflake/test_snowflake_table_builder.py (44 additions, 40 deletions)

@@ -6,7 +6,7 @@
from dlt.common.utils import uniq_id
from dlt.common.schema import Schema, utils
from dlt.destinations import snowflake
from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient
from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient, SUPPORTED_HINTS
from dlt.destinations.impl.snowflake.configuration import (
    SnowflakeClientConfiguration,
    SnowflakeCredentials,
@@ -66,59 +66,63 @@ def test_create_table(snowflake_client: SnowflakeClient) -> None:

    assert sql.strip().startswith("CREATE TABLE")
    assert "EVENT_TEST_TABLE" in sql
    assert '"COL1" NUMBER(19,0) NOT NULL' in sql
    assert '"COL2" FLOAT NOT NULL' in sql
    assert '"COL3" BOOLEAN NOT NULL' in sql
    assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql
    assert '"COL5" VARCHAR' in sql
    assert '"COL6" NUMBER(38,9) NOT NULL' in sql
    assert '"COL7" BINARY' in sql
    assert '"COL8" NUMBER(38,0)' in sql
    assert '"COL9" VARIANT NOT NULL' in sql
    assert '"COL10" DATE NOT NULL' in sql


def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None:
    mod_update = deepcopy(TABLE_UPDATE)
    mod_update = deepcopy(TABLE_UPDATE[:11])
    # mock hints
    snowflake_client.config.create_indexes = True
    snowflake_client.active_hints = SUPPORTED_HINTS

    mod_update[0]["primary_key"] = True
    mod_update[5]["primary_key"] = True

    mod_update[0]["sort"] = True

    # unique constraints are always single columns
    mod_update[1]["unique"] = True
    mod_update[7]["unique"] = True

    mod_update[4]["parent_key"] = True

    sql = ";".join(snowflake_client._get_table_update_sql("event_test_table", mod_update, False))

    assert sql.strip().startswith("CREATE TABLE")
    assert "EVENT_TEST_TABLE" in sql
    assert '"COL1" NUMBER(19,0) NOT NULL' in sql
    assert '"COL2" FLOAT NOT NULL' in sql
    assert '"COL2" FLOAT UNIQUE NOT NULL' in sql
    assert '"COL3" BOOLEAN NOT NULL' in sql
    assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql
    assert '"COL5" VARCHAR' in sql
    assert '"COL6" NUMBER(38,9) NOT NULL' in sql
    assert '"COL7" BINARY' in sql
    assert '"COL8" NUMBER(38,0)' in sql
    assert '"COL8" NUMBER(38,0) UNIQUE' in sql
    assert '"COL9" VARIANT NOT NULL' in sql
    assert '"COL10" DATE NOT NULL' in sql

    # same thing with indexes
    snowflake_client = snowflake().client(
        snowflake_client.schema,
        SnowflakeClientConfiguration(create_indexes=True)._bind_dataset_name(
            dataset_name="test_" + uniq_id()
        ),
    )
    sql_statements = snowflake_client._get_table_update_sql("event_test_table", mod_update, False)

    for stmt in sql_statements:
        sqlfluff.parse(stmt)

    assert any(
        'ADD CONSTRAINT PK_EVENT_TEST_TABLE PRIMARY KEY ("col1")' in stmt for stmt in sql_statements
    )
    assert any(
        'ADD CONSTRAINT UQ_EVENT_TEST_TABLE UNIQUE ("col2")' in stmt for stmt in sql_statements
    )

    # PRIMARY KEY constraint
    assert 'CONSTRAINT "PK_EVENT_TEST_TABLE_' in sql
    assert 'PRIMARY KEY ("COL1", "COL6")' in sql

    # generate alter
    mod_update = deepcopy(TABLE_UPDATE[11:])
    mod_update[0]["primary_key"] = True
    mod_update[1]["unique"] = True

    sql = ";".join(snowflake_client._get_table_update_sql("event_test_table", mod_update, True))
    # PK constraint ignored for alter
    assert "PRIMARY KEY" not in sql
    assert '"COL2_NULL" FLOAT UNIQUE' in sql


def test_alter_table(snowflake_client: SnowflakeClient) -> None:
@@ -133,23 +137,23 @@ def test_alter_table(snowflake_client: SnowflakeClient) -> None:
    assert sql.count("ALTER TABLE") == 1
    assert sql.count("ADD COLUMN") == 1
    assert '"EVENT_TEST_TABLE"' in sql
    assert '"COL1" NUMBER(19,0) NOT NULL' in sql
    assert '"COL2" FLOAT NOT NULL' in sql
    assert '"COL3" BOOLEAN NOT NULL' in sql
    assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql
    assert '"COL5" VARCHAR' in sql
    assert '"COL6" NUMBER(38,9) NOT NULL' in sql
    assert '"COL7" BINARY' in sql
    assert '"COL8" NUMBER(38,0)' in sql
    assert '"COL9" VARIANT NOT NULL' in sql
    assert '"COL10" DATE' in sql

    mod_table = deepcopy(TABLE_UPDATE)
    mod_table.pop(0)
    sql = snowflake_client._get_table_update_sql("event_test_table", mod_table, True)[0]

    assert '"COL1"' not in sql
    assert '"COL2" FLOAT NOT NULL' in sql


def test_create_table_case_sensitive(cs_client: SnowflakeClient) -> None:
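Worth remembering when reading these tests: Snowflake records PRIMARY KEY and UNIQUE constraints as metadata but does not enforce them, so the emitted constraints are informational only. A quick sketch for checking what actually got recorded, assuming a connected client as in the tests above:

# metadata-only check: list the constraints Snowflake recorded for the test table
rows = client.sql_client.execute_sql(
    "SELECT constraint_name, constraint_type"
    " FROM information_schema.table_constraints"
    " WHERE table_name = 'EVENT_TEST_TABLE'"
)
for constraint_name, constraint_type in rows:
    print(constraint_name, constraint_type)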
