
feat: snowflake hints #2143

Closed
wants to merge 10 commits
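
In short, this adds optional `unique` / `primary_key` column hints for the Snowflake destination, gated behind a new `create_indexes` configuration flag. A minimal usage sketch (illustrative only; it assumes Snowflake credentials are already configured and that the `snowflake` destination factory forwards `create_indexes` to `SnowflakeClientConfiguration`):

```py
import dlt

# column hints: "id" becomes a PRIMARY KEY constraint and "email" gets a UNIQUE hint,
# but only when create_indexes is enabled on the destination
@dlt.resource(primary_key="id", columns={"email": {"unique": True}})
def users():
    yield [{"id": 1, "email": "alice@example.com"}]

pipeline = dlt.pipeline(
    pipeline_name="users_pipeline",
    destination=dlt.destinations.snowflake(create_indexes=True),
    dataset_name="users_data",
)
print(pipeline.run(users()))
```
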
5 changes: 1 addition & 4 deletions dlt/destinations/impl/bigquery/bigquery.py
@@ -401,10 +401,7 @@ def _get_info_schema_columns_query(
return query, folded_table_names

def _get_column_def_sql(self, column: TColumnSchema, table: PreparedTableSchema = None) -> str:
name = self.sql_client.escape_column_name(column["name"])
column_def_sql = (
f"{name} {self.type_mapper.to_destination_type(column, table)} {self._gen_not_null(column.get('nullable', True))}"
)
column_def_sql = super()._get_column_def_sql(column, table)
if column.get(ROUND_HALF_EVEN_HINT, False):
column_def_sql += " OPTIONS (rounding_mode='ROUND_HALF_EVEN')"
if column.get(ROUND_HALF_AWAY_FROM_ZERO_HINT, False):
5 changes: 2 additions & 3 deletions dlt/destinations/impl/clickhouse/clickhouse.py
@@ -292,11 +292,10 @@ def _get_table_update_sql(

return sql

@staticmethod
def _gen_not_null(v: bool) -> str:
def _gen_not_null(self, v: bool) -> str:
# ClickHouse fields are not nullable by default.
# We use the `Nullable` modifier instead of NULL / NOT NULL modifiers to cater for ALTER statement.
pass
return ""

def _from_db_type(
self, ch_t: str, precision: Optional[int], scale: Optional[int]
6 changes: 0 additions & 6 deletions dlt/destinations/impl/databricks/databricks.py
@@ -264,12 +264,6 @@ def _from_db_type(
) -> TColumnType:
return self.type_mapper.from_destination_type(bq_t, precision, scale)

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
name = self.sql_client.escape_column_name(c["name"])
return (
f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}"
)

def _get_storage_table_query_columns(self) -> List[str]:
fields = super()._get_storage_table_query_columns()
fields[2] = ( # Override because this is the only way to get data type with precision
6 changes: 0 additions & 6 deletions dlt/destinations/impl/dremio/dremio.py
@@ -151,12 +151,6 @@ def _from_db_type(
) -> TColumnType:
return self.type_mapper.from_destination_type(bq_t, precision, scale)

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
name = self.sql_client.escape_column_name(c["name"])
return (
f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}"
)

def _create_merge_followup_jobs(
self, table_chain: Sequence[PreparedTableSchema]
) -> List[FollowupJobRequest]:
11 changes: 0 additions & 11 deletions dlt/destinations/impl/duckdb/duck.py
@@ -74,17 +74,6 @@ def create_load_job(
job = DuckDbCopyJob(file_path)
return job

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
hints_str = " ".join(
self.active_hints.get(h, "")
for h in self.active_hints.keys()
if c.get(h, False) is True
)
column_name = self.sql_client.escape_column_name(c["name"])
return (
f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}"
)

def _from_db_type(
self, pq_t: str, precision: Optional[int], scale: Optional[int]
) -> TColumnType:
6 changes: 1 addition & 5 deletions dlt/destinations/impl/mssql/mssql.py
@@ -115,11 +115,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = Non
else:
db_type = self.type_mapper.to_destination_type(c, table)

hints_str = " ".join(
self.active_hints.get(h, "")
for h in self.active_hints.keys()
if c.get(h, False) is True
)
hints_str = self._get_column_hints_sql(c)
column_name = self.sql_client.escape_column_name(c["name"])
return f"{column_name} {db_type} {hints_str} {self._gen_not_null(c.get('nullable', True))}"

12 changes: 0 additions & 12 deletions dlt/destinations/impl/postgres/postgres.py
@@ -161,18 +161,6 @@ def create_load_job(
job = PostgresCsvCopyJob(file_path)
return job

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
hints_ = " ".join(
self.active_hints.get(h, "")
for h in self.active_hints.keys()
if c.get(h, False) is True
)
column_name = self.sql_client.escape_column_name(c["name"])
nullability = self._gen_not_null(c.get("nullable", True))
column_type = self.type_mapper.to_destination_type(c, table)

return f"{column_name} {column_type} {hints_} {nullability}"

def _create_replace_followup_jobs(
self, table_chain: Sequence[PreparedTableSchema]
) -> List[FollowupJobRequest]:
12 changes: 1 addition & 11 deletions dlt/destinations/impl/redshift/redshift.py
@@ -153,6 +153,7 @@ def __init__(
capabilities,
)
super().__init__(schema, config, sql_client)
self.active_hints = HINT_TO_REDSHIFT_ATTR
self.sql_client = sql_client
self.config: RedshiftClientConfiguration = config
self.type_mapper = self.capabilities.get_type_mapper()
@@ -162,17 +163,6 @@ def _create_merge_followup_jobs(
) -> List[FollowupJobRequest]:
return [RedshiftMergeJob.from_table_chain(table_chain, self.sql_client)]

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
hints_str = " ".join(
HINT_TO_REDSHIFT_ATTR.get(h, "")
for h in HINT_TO_REDSHIFT_ATTR.keys()
if c.get(h, False) is True
)
column_name = self.sql_client.escape_column_name(c["name"])
return (
f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}"
)

def create_load_job(
self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False
) -> LoadJob:
18 changes: 18 additions & 0 deletions dlt/destinations/impl/snowflake/configuration.py
@@ -138,6 +138,24 @@ class SnowflakeClientConfiguration(DestinationClientDwhWithStagingConfiguration)
query_tag: Optional[str] = None
"""A tag with placeholders to tag sessions executing jobs"""

create_indexes: bool = False
"""Whether UNIQUE or PRIMARY KEY constrains should be created"""

def __init__(
self,
*,
credentials: SnowflakeCredentials = None,
create_indexes: bool = False,
destination_name: str = None,
environment: str = None,
) -> None:
super().__init__(
credentials=credentials,
destination_name=destination_name,
environment=environment,
)
self.create_indexes = create_indexes

def fingerprint(self) -> str:
"""Returns a fingerprint of host part of a connection string"""
if self.credentials and self.credentials.host:
45 changes: 36 additions & 9 deletions dlt/destinations/impl/snowflake/snowflake.py
@@ -1,6 +1,7 @@
from typing import Optional, Sequence, List
from typing import Optional, Sequence, List, Dict, Set
from urllib.parse import urlparse, urlunparse

from dlt.common import logger
from dlt.common.data_writers.configuration import CsvFormatConfiguration
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import (
@@ -15,20 +16,24 @@
AwsCredentialsWithoutDefaults,
AzureCredentialsWithoutDefaults,
)
from dlt.common.schema.utils import get_columns_names_with_prop
from dlt.common.storages.configuration import FilesystemConfiguration, ensure_canonical_az_url
from dlt.common.storages.file_storage import FileStorage
from dlt.common.schema import TColumnSchema, Schema
from dlt.common.schema.typing import TColumnType
from dlt.common.schema import TColumnSchema, Schema, TColumnHint
from dlt.common.schema.typing import TColumnType, TTableSchema

from dlt.common.storages.fsspec_filesystem import AZURE_BLOB_STORAGE_PROTOCOLS, S3_PROTOCOLS
from dlt.common.typing import TLoaderFileFormat
from dlt.common.utils import uniq_id
from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset
from dlt.destinations.exceptions import LoadJobTerminalException

from dlt.destinations.impl.snowflake.configuration import SnowflakeClientConfiguration
from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient
from dlt.destinations.job_impl import ReferenceFollowupJobRequest

SUPPORTED_HINTS: Dict[TColumnHint, str] = {"unique": "UNIQUE"}


class SnowflakeLoadJob(RunnableLoadJob, HasFollowupJobs):
def __init__(
@@ -238,6 +243,7 @@ def __init__(
self.config: SnowflakeClientConfiguration = config
self.sql_client: SnowflakeSqlClient = sql_client # type: ignore
self.type_mapper = self.capabilities.get_type_mapper()
self.active_hints = SUPPORTED_HINTS if self.config.create_indexes else {}

def create_load_job(
self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False
@@ -264,6 +270,33 @@ def _make_add_column_sql(
"ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns)
]

def _get_constraints_sql(
self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
) -> str:
# "primary_key": "PRIMARY KEY"
if self.config.create_indexes:
partial: TTableSchema = {
"name": table_name,
"columns": {c["name"]: c for c in new_columns},
}
# Add PK constraint if pk_columns exist
pk_columns = get_columns_names_with_prop(partial, "primary_key")
if pk_columns:
if generate_alter:
logger.warning(
f"PRIMARY KEY on {table_name} constraint cannot be added in ALTER TABLE and"
" is ignored"
)
else:
pk_constraint_name = list(
self._norm_and_escape_columns(f"PK_{table_name}_{uniq_id(4)}")
)[0]
quoted_pk_cols = ", ".join(
self.sql_client.escape_column_name(col) for col in pk_columns
)
return f",\nCONSTRAINT {pk_constraint_name} PRIMARY KEY ({quoted_pk_cols})"
return ""

def _get_table_update_sql(
self,
table_name: str,
@@ -287,11 +320,5 @@ def _from_db_type(
) -> TColumnType:
return self.type_mapper.from_destination_type(bq_t, precision, scale)

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
name = self.sql_client.escape_column_name(c["name"])
return (
f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}"
)

def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool:
return self.config.truncate_tables_on_staging_destination_before_load
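
With `create_indexes` enabled, `_get_constraints_sql` above appends a table-level constraint to the generated `CREATE TABLE` statement, roughly `, CONSTRAINT PK_USERS_XXXX PRIMARY KEY ("ID", "VERSION")` for a compound key on `id` and `version` (the `XXXX` suffix is illustrative and comes from `uniq_id(4)`; casing and quoting follow the configured naming convention and `escape_column_name`). On the `ALTER TABLE` path the primary key is skipped with a warning.
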
39 changes: 28 additions & 11 deletions dlt/destinations/job_client_impl.py
@@ -7,28 +7,26 @@
from typing import (
Any,
ClassVar,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Iterable,
Iterator,
Generator,
)
import zlib
import re
from contextlib import contextmanager
from contextlib import suppress

from dlt.common import pendulum, logger
from dlt.common.destination.capabilities import DataTypeMapper
from dlt.common.json import json
from dlt.common.schema.typing import (
C_DLT_LOAD_ID,
COLUMN_HINTS,
TColumnType,
TColumnSchemaBase,
TTableFormat,
)
from dlt.common.schema.utils import (
get_inherited_table_hint,
@@ -40,11 +38,11 @@
from dlt.common.storages import FileStorage
from dlt.common.storages.load_package import LoadJobInfo, ParsedLoadJobFileName
from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables
from dlt.common.schema import TColumnHint
from dlt.common.destination.reference import (
PreparedTableSchema,
StateInfo,
StorageSchemaInfo,
SupportsReadableDataset,
WithStateSync,
DestinationClientConfiguration,
DestinationClientDwhConfiguration,
@@ -55,9 +53,7 @@
JobClientBase,
HasFollowupJobs,
CredentialsConfiguration,
SupportsReadableRelation,
)
from dlt.destinations.dataset import ReadableDBAPIDataset

from dlt.destinations.exceptions import DatabaseUndefinedRelation
from dlt.destinations.job_impl import (
@@ -154,6 +150,8 @@ def __init__(
self.state_table_columns = ", ".join(
sql_client.escape_column_name(col) for col in state_table_["columns"]
)
self.active_hints: Dict[TColumnHint, str] = {}
self.type_mapper: DataTypeMapper = None
super().__init__(schema, config, sql_client.capabilities)
self.sql_client = sql_client
assert isinstance(config, DestinationClientDwhConfiguration)
@@ -569,6 +567,7 @@ def _get_table_update_sql(
# build CREATE
sql = self._make_create_table(qualified_name, table) + " (\n"
sql += ",\n".join([self._get_column_def_sql(c, table) for c in new_columns])
sql += self._get_constraints_sql(table_name, new_columns, generate_alter)
sql += ")"
sql_result.append(sql)
else:
@@ -582,8 +581,16 @@
sql_result.extend(
[sql_base + col_statement for col_statement in add_column_statements]
)
constraints_sql = self._get_constraints_sql(table_name, new_columns, generate_alter)
if constraints_sql:
sql_result.append(constraints_sql)
return sql_result

def _get_constraints_sql(
self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
) -> str:
return ""

def _check_table_update_hints(
self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
) -> None:
@@ -613,12 +620,22 @@ def _check_table_update_hints(
" existing tables."
)

@abstractmethod
def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
pass
hints_ = self._get_column_hints_sql(c)
column_name = self.sql_client.escape_column_name(c["name"])
nullability = self._gen_not_null(c.get("nullable", True))
column_type = self.type_mapper.to_destination_type(c, table)

return f"{column_name} {column_type} {hints_} {nullability}"

def _get_column_hints_sql(self, c: TColumnSchema) -> str:
return " ".join(
self.active_hints.get(h, "")
for h in self.active_hints.keys()
if c.get(h, False) is True # use ColumnPropInfos to get default value
)

@staticmethod
def _gen_not_null(nullable: bool) -> str:
def _gen_not_null(self, nullable: bool) -> str:
return "NOT NULL" if not nullable else ""

def _create_table_update(
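
The net effect of moving `_get_column_def_sql` into the base class is that a destination client no longer overrides it; it only populates `self.active_hints` (and `self.type_mapper`) and lets `_get_column_hints_sql` / `_gen_not_null` render the column DDL. A minimal illustrative subclass (hypothetical, not part of this PR):

```py
from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset


class MyDestinationClient(SqlJobClientWithStagingDataset):
    """Hypothetical client showing the post-refactor hint wiring."""

    def __init__(self, schema, config, sql_client) -> None:
        super().__init__(schema, config, sql_client)
        self.sql_client = sql_client
        # map column hint names to destination-specific keywords;
        # the base _get_column_def_sql() appends them to each column definition
        self.active_hints = {"unique": "UNIQUE"}
        self.type_mapper = self.capabilities.get_type_mapper()
```
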
9 changes: 9 additions & 0 deletions docs/website/docs/dlt-ecosystem/destinations/snowflake.md
@@ -200,6 +200,12 @@ Note that we ignore missing columns `ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE` and
## Supported column hints
Snowflake supports the following [column hints](../../general-usage/schema#tables-and-columns):
* `cluster` - Creates cluster column(s). Multiple columns per table are supported, but only when a new table is created.
* `unique` - Creates a UNIQUE hint on a Snowflake column; can be added to multiple columns. ([optional](#additional-destination-options))
* `primary_key` - Creates a PRIMARY KEY on the selected column(s); the key may be compound. ([optional](#additional-destination-options))

`unique` and `primary_key` are not enforced, and `dlt` does not instruct Snowflake to `RELY` on them during query planning.
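
For example, the hints can be declared directly on a resource (a minimal sketch; `users`, `id`, and `email` are placeholder names):

```py
import dlt

@dlt.resource
def users():
    yield [{"id": 1, "email": "alice@example.com"}]

# attach the hints; they are rendered as constraints only when
# `create_indexes` is enabled for the Snowflake destination (see below)
users_resource = users()
users_resource.apply_hints(primary_key="id", columns={"email": {"unique": True}})
```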


## Table and column identifiers
Snowflake supports both case-sensitive and case-insensitive identifiers. All unquoted and uppercase identifiers resolve case-insensitively in SQL statements. Case-insensitive [naming conventions](../../general-usage/naming-convention.md#case-sensitive-and-insensitive-destinations) like the default **snake_case** will generate case-insensitive identifiers. Case-sensitive (like **sql_cs_v1**) will generate
@@ -308,13 +314,16 @@ pipeline = dlt.pipeline(
## Additional destination options

You can define your own stage to PUT files and disable the removal of the staged files after loading.
You can also opt in to [creating indexes](#supported-column-hints).

```toml
[destination.snowflake]
# Use an existing named stage instead of the default. Default uses the implicit table stage per table
stage_name="DLT_STAGE"
# Whether to keep or delete the staged files after COPY INTO succeeds
keep_staged_files=true
# Add UNIQUE and PRIMARY KEY hints to tables
create_indexes=true
```
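
The same option can be set without a TOML file, e.g. via dlt's standard environment-variable configuration (a sketch; the variable name follows the usual `DESTINATION__<NAME>__<OPTION>` layout):

```py
import os

# equivalent to create_indexes=true under [destination.snowflake]
os.environ["DESTINATION__SNOWFLAKE__CREATE_INDEXES"] = "true"
```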

### Setting up CSV format