Feat/1711 create with not exists dlt tables #1740

Merged · 6 commits · Aug 26, 2024
Changes from 5 commits
1 change: 1 addition & 0 deletions dlt/common/destination/capabilities.py
@@ -76,6 +76,7 @@ class DestinationCapabilitiesContext(ContainerInjectableContext):
# use naming convention in the schema
naming_convention: TNamingConventionReferenceArg = None
alter_add_multi_column: bool = True
create_table_not_exists: bool = True
Contributor:
I suggest renaming it so it's clear it's for _dlt tables:
for example, create_dlt_table_if_not_exists

Collaborator (Author):
This is in destination caps, but I will rename it to make clear what it means.

supports_truncate_command: bool = True
schema_supports_numeric_precision: bool = True
timestamp_precision: int = 6
8 changes: 8 additions & 0 deletions dlt/common/destination/reference.py
@@ -383,9 +383,17 @@ def run_managed(
except (DestinationTerminalException, TerminalValueError) as e:
self._state = "failed"
self._exception = e
logger.exception(
f"Terminal exception in job {self.job_id()} on table {self.load_table_name} in file"
f" {self._file_path}"
)
except (DestinationTransientException, Exception) as e:
self._state = "retry"
self._exception = e
logger.exception(
f"Transient exception in job {self.job_id()} on table {self.load_table_name} in"
f" file {self._file_path}"
)
finally:
self._finished_at = pendulum.now()
# sanity check
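The two logger.exception calls added above follow the standard pattern of logging inside the except block, so the active traceback is attached to the record before the finally block runs. A minimal standalone sketch of that pattern — the job stand-in, exception class, and messages below are illustrative, not dlt's actual classes:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("dlt.sketch")


class DestinationTerminalException(Exception):
    """Stand-in for dlt's terminal (non-retryable) error class."""


def run_once(action) -> str:
    """Run an action, classify the failure, and log the full traceback."""
    try:
        action()
        return "completed"
    except DestinationTerminalException:
        # logger.exception logs at ERROR level and appends the current traceback
        logger.exception("Terminal exception - job will not be retried")
        return "failed"
    except Exception:
        logger.exception("Transient exception - job will be retried")
        return "retry"


def flaky():
    raise ValueError("connection reset")


print(run_once(flaky))  # logs the ValueError traceback, prints "retry"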
12 changes: 5 additions & 7 deletions dlt/common/normalizers/json/relational.py
@@ -184,11 +184,10 @@ def _get_child_row_hash(parent_row_id: str, child_table: str, list_idx: int) ->
# and all child tables must be lists
return digest128(f"{parent_row_id}_{child_table}_{list_idx}", DLT_ID_LENGTH_BYTES)

@staticmethod
def _link_row(row: DictStrAny, parent_row_id: str, list_idx: int) -> DictStrAny:
def _link_row(self, row: DictStrAny, parent_row_id: str, list_idx: int) -> DictStrAny:
assert parent_row_id
row["_dlt_parent_id"] = parent_row_id
row["_dlt_list_idx"] = list_idx
row[self.c_dlt_parent_id] = parent_row_id
row[self.c_dlt_list_idx] = list_idx

return row

@@ -227,7 +226,7 @@ def _add_row_id(
if row_id_type == "row_hash":
row_id = DataItemNormalizer._get_child_row_hash(parent_row_id, table, pos)
# link to parent table
DataItemNormalizer._link_row(flattened_row, parent_row_id, pos)
self._link_row(flattened_row, parent_row_id, pos)

flattened_row[self.c_dlt_id] = row_id
return row_id
@@ -260,7 +259,6 @@ def _normalize_list(
parent_row_id: Optional[str] = None,
_r_lvl: int = 0,
) -> TNormalizedRowIterator:
v: DictStrAny = None
table = self.schema.naming.shorten_fragments(*parent_path, *ident_path)

for idx, v in enumerate(seq):
@@ -285,7 +283,7 @@ def _normalize_list(
child_row_hash = DataItemNormalizer._get_child_row_hash(parent_row_id, table, idx)
wrap_v = wrap_in_dict(v)
wrap_v[self.c_dlt_id] = child_row_hash
e = DataItemNormalizer._link_row(wrap_v, parent_row_id, idx)
e = self._link_row(wrap_v, parent_row_id, idx)
DataItemNormalizer._extend_row(extend, e)
yield (table, self.schema.naming.shorten_fragments(*parent_path)), e

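For context on why _link_row changed from a staticmethod to an instance method: the dlt column names depend on the schema's naming convention and are resolved on the normalizer instance (self.c_dlt_parent_id, self.c_dlt_list_idx), so they can no longer be hardcoded literals. A minimal sketch of the idea — the Naming stub and attribute wiring below are illustrative, not dlt's actual API:

from typing import Any, Dict


class Naming:
    """Illustrative stand-in for a schema naming convention."""

    def normalize_identifier(self, identifier: str) -> str:
        # a snake_case convention is a no-op; a camelCase one could
        # turn "_dlt_parent_id" into something like "dltParentId"
        return identifier


class NormalizerSketch:
    def __init__(self, naming: Naming) -> None:
        # resolved once per schema, then reused for every linked row
        self.c_dlt_parent_id = naming.normalize_identifier("_dlt_parent_id")
        self.c_dlt_list_idx = naming.normalize_identifier("_dlt_list_idx")

    def _link_row(self, row: Dict[str, Any], parent_row_id: str, list_idx: int) -> Dict[str, Any]:
        # instance attributes replace the previously hardcoded literals
        row[self.c_dlt_parent_id] = parent_row_id
        row[self.c_dlt_list_idx] = list_idx
        return row


row = NormalizerSketch(Naming())._link_row({}, "abc123", 0)
assert row == {"_dlt_parent_id": "abc123", "_dlt_list_idx": 0}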
2 changes: 1 addition & 1 deletion dlt/destinations/impl/athena/athena.py
@@ -452,7 +452,7 @@ def _get_table_update_sql(
partition_clause = self._iceberg_partition_clause(
cast(Optional[Dict[str, str]], table.get(PARTITION_HINT))
)
sql.append(f"""CREATE TABLE {qualified_table_name}
sql.append(f"""{self._make_create_table(qualified_table_name, table)}
({columns})
{partition_clause}
LOCATION '{location.rstrip('/')}'
5 changes: 4 additions & 1 deletion dlt/destinations/impl/filesystem/filesystem.py
@@ -303,6 +303,7 @@ def update_stored_schema(
only_tables: Iterable[str] = None,
expected_update: TSchemaTables = None,
) -> TSchemaTables:
applied_update = super().update_stored_schema(only_tables, expected_update)
# create destination dirs for all tables
table_names = only_tables or self.schema.tables.keys()
dirs_to_create = self.get_table_dirs(table_names)
@@ -316,7 +317,9 @@
if not self.config.as_staging:
self._store_current_schema()

return expected_update
# we assume that expected_update == applied_update so table schemas in dest were not
# externally changed
return applied_update

def get_table_dir(self, table_name: str, remote: bool = False) -> str:
# dlt tables do not respect layout (for now)
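A compressed sketch of the corrected control flow in update_stored_schema (class names abbreviated and bodies stubbed; the super() call and the returned value are the substance of the change):

class BaseJobClient:
    def update_stored_schema(self, only_tables=None, expected_update=None):
        # the base class computes the schema update that was actually applied
        return {"my_table": ["col_a"]}  # stub


class FilesystemClientSketch(BaseJobClient):
    def update_stored_schema(self, only_tables=None, expected_update=None):
        applied_update = super().update_stored_schema(only_tables, expected_update)
        # ... create table directories, store the current schema ...
        # previously expected_update was returned untouched; returning the
        # applied update assumes no external changes to the destination schemas
        return applied_update


assert FilesystemClientSketch().update_stored_schema() == {"my_table": ["col_a"]}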
1 change: 1 addition & 0 deletions dlt/destinations/impl/mssql/factory.py
@@ -37,6 +37,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext:
caps.max_text_data_type_length = 2**30 - 1
caps.is_max_text_data_type_length_in_bytes = False
caps.supports_ddl_transactions = True
caps.create_table_not_exists = False # IF NOT EXISTS not supported
caps.max_rows_per_insert = 1000
caps.timestamp_precision = 7
caps.supported_merge_strategies = ["delete-insert", "upsert", "scd2"]
2 changes: 2 additions & 0 deletions dlt/destinations/impl/synapse/factory.py
@@ -63,6 +63,8 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext:
caps.supports_transactions = True
caps.supports_ddl_transactions = False

caps.create_table_not_exists = False # IF NOT EXISTS on CREATE TABLE not supported

# Synapse throws "Some part of your SQL statement is nested too deeply. Rewrite the query or break it up into smaller queries."
# if number of records exceeds a certain number. Which exact number that is seems not deterministic:
# in tests, I've seen a query with 12230 records run successfully on one run, but fail on a subsequent run, while the query remained exactly the same.
19 changes: 14 additions & 5 deletions dlt/destinations/job_client_impl.py
@@ -522,22 +522,31 @@ def _make_add_column_sql(
"""Make one or more ADD COLUMN sql clauses to be joined in ALTER TABLE statement(s)"""
return [f"ADD COLUMN {self._get_column_def_sql(c, table_format)}" for c in new_columns]

def _make_create_table(self, qualified_name: str, table: TTableSchema) -> str:
not_exists_clause = " "
if (
table["name"] in self.schema.dlt_table_names()
and self.capabilities.create_table_not_exists
):
not_exists_clause = " IF NOT EXISTS "
return f"CREATE TABLE{not_exists_clause}{qualified_name}"

def _get_table_update_sql(
self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
) -> List[str]:
# build sql
canonical_name = self.sql_client.make_qualified_table_name(table_name)
qualified_name = self.sql_client.make_qualified_table_name(table_name)
table = self.prepare_load_table(table_name)
table_format = table.get("table_format")
sql_result: List[str] = []
if not generate_alter:
# build CREATE
sql = f"CREATE TABLE {canonical_name} (\n"
sql = self._make_create_table(qualified_name, table) + " (\n"
sql += ",\n".join([self._get_column_def_sql(c, table_format) for c in new_columns])
sql += ")"
sql_result.append(sql)
else:
sql_base = f"ALTER TABLE {canonical_name}\n"
sql_base = f"ALTER TABLE {qualified_name}\n"
add_column_statements = self._make_add_column_sql(new_columns, table_format)
if self.capabilities.alter_add_multi_column:
column_sql = ",\n"
@@ -561,13 +570,13 @@ def _get_table_update_sql(
if hint == "not_null":
logger.warning(
f"Column(s) {hint_columns} with NOT NULL are being added to existing"
f" table {canonical_name}. If there's data in the table the operation"
f" table {qualified_name}. If there's data in the table the operation"
" will fail."
)
else:
logger.warning(
f"Column(s) {hint_columns} with hint {hint} are being added to existing"
f" table {canonical_name}. Several hint types may not be added to"
f" table {qualified_name}. Several hint types may not be added to"
" existing tables."
)
return sql_result
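To make the new gating concrete, here is a standalone restatement of _make_create_table with its inputs made explicit. The helper below is illustrative (the set of dlt table names is assumed to be the standard ones); only the clause logic mirrors the method above:

from typing import Set


def make_create_table(
    qualified_name: str,
    table_name: str,
    dlt_table_names: Set[str],
    create_table_not_exists: bool,
) -> str:
    # IF NOT EXISTS is emitted only for dlt-internal tables, and only when
    # the destination capability allows it
    not_exists_clause = " "
    if table_name in dlt_table_names and create_table_not_exists:
        not_exists_clause = " IF NOT EXISTS "
    return f"CREATE TABLE{not_exists_clause}{qualified_name}"


dlt_tables = {"_dlt_version", "_dlt_loads", "_dlt_pipeline_state"}

# user tables are never affected
assert make_create_table('"db"."events"', "events", dlt_tables, True) == 'CREATE TABLE "db"."events"'
# dlt tables get the clause on capable destinations (e.g. Postgres)...
assert (
    make_create_table('"db"."_dlt_version"', "_dlt_version", dlt_tables, True)
    == 'CREATE TABLE IF NOT EXISTS "db"."_dlt_version"'
)
# ...but not where the capability is off (MSSQL, Synapse)
assert (
    make_create_table('"db"."_dlt_version"', "_dlt_version", dlt_tables, False)
    == 'CREATE TABLE "db"."_dlt_version"'
)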
@@ -22,6 +22,7 @@
from typing import List, Dict, Any, Generator
import dlt


# Define a dlt resource with write disposition to 'merge'
@dlt.resource(name="parent_with_children", write_disposition={"disposition": "merge"})
def data_source() -> Generator[List[Dict[str, Any]], None, None]:
@@ -44,13 +45,15 @@ def data_source() -> Generator[List[Dict[str, Any]], None, None]:

yield data


# Function to add parent_id to each child record within a parent record
def add_parent_id(record: Dict[str, Any]) -> Dict[str, Any]:
parent_id_key = "parent_id"
for child in record["children"]:
child[parent_id_key] = record[parent_id_key]
return record


if __name__ == "__main__":
# Create and configure the dlt pipeline
pipeline = dlt.pipeline(
@@ -60,10 +63,6 @@ def add_parent_id(record: Dict[str, Any]) -> Dict[str, Any]:
)

# Run the pipeline
load_info = pipeline.run(
data_source()
.add_map(add_parent_id),
primary_key="parent_id"
)
load_info = pipeline.run(data_source().add_map(add_parent_id), primary_key="parent_id")
# Output the load information after pipeline execution
print(load_info)
@@ -1,4 +1,3 @@

import pytest

from tests.utils import skipifgithubfork
@@ -29,6 +28,7 @@
from typing import List, Dict, Any, Generator
import dlt


# Define a dlt resource with write disposition to 'merge'
@dlt.resource(name="parent_with_children", write_disposition={"disposition": "merge"})
def data_source() -> Generator[List[Dict[str, Any]], None, None]:
@@ -51,13 +51,15 @@ def data_source() -> Generator[List[Dict[str, Any]], None, None]:

yield data


# Function to add parent_id to each child record within a parent record
def add_parent_id(record: Dict[str, Any]) -> Dict[str, Any]:
parent_id_key = "parent_id"
for child in record["children"]:
child[parent_id_key] = record[parent_id_key]
return record


@skipifgithubfork
@pytest.mark.forked
def test_parent_child_relationship():
@@ -69,10 +71,6 @@ def test_parent_child_relationship():
)

# Run the pipeline
load_info = pipeline.run(
data_source()
.add_map(add_parent_id),
primary_key="parent_id"
)
load_info = pipeline.run(data_source().add_map(add_parent_id), primary_key="parent_id")
# Output the load information after pipeline execution
print(load_info)
12 changes: 10 additions & 2 deletions tests/load/mssql/test_mssql_table_builder.py
@@ -55,8 +55,8 @@ def test_alter_table(client: MsSqlJobClient) -> None:
# existing table has no columns
sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, True)[0]
sqlfluff.parse(sql, dialect="tsql")
canonical_name = client.sql_client.make_qualified_table_name("event_test_table")
assert sql.count(f"ALTER TABLE {canonical_name}\nADD") == 1
qualified_name = client.sql_client.make_qualified_table_name("event_test_table")
assert sql.count(f"ALTER TABLE {qualified_name}\nADD") == 1
assert "event_test_table" in sql
assert '"col1" bigint NOT NULL' in sql
assert '"col2" float NOT NULL' in sql
@@ -75,3 +75,11 @@ def test_alter_table(client: MsSqlJobClient) -> None:
assert '"col6_precision" decimal(6,2) NOT NULL' in sql
assert '"col7_precision" varbinary(19)' in sql
assert '"col11_precision" time(3) NOT NULL' in sql


def test_create_dlt_table(client: MsSqlJobClient) -> None:
# non existing table
sql = client._get_table_update_sql("_dlt_version", TABLE_UPDATE, False)[0]
sqlfluff.parse(sql, dialect="tsql")
qualified_name = client.sql_client.make_qualified_table_name("_dlt_version")
assert f"CREATE TABLE {qualified_name}" in sql
11 changes: 10 additions & 1 deletion tests/load/postgres/test_postgres_table_builder.py
@@ -57,7 +57,8 @@ def test_create_table(client: PostgresClient) -> None:
# non existing table
sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, False)[0]
sqlfluff.parse(sql, dialect="postgres")
assert "event_test_table" in sql
qualified_name = client.sql_client.make_qualified_table_name("event_test_table")
assert f"CREATE TABLE {qualified_name}" in sql
assert '"col1" bigint NOT NULL' in sql
assert '"col2" double precision NOT NULL' in sql
assert '"col3" boolean NOT NULL' in sql
@@ -173,3 +174,11 @@ def test_create_table_case_sensitive(cs_client: PostgresClient) -> None:
# every line starts with "Col"
for line in sql.split("\n")[1:]:
assert line.startswith('"Col')


def test_create_dlt_table(client: PostgresClient) -> None:
# non existing table
sql = client._get_table_update_sql("_dlt_version", TABLE_UPDATE, False)[0]
sqlfluff.parse(sql, dialect="postgres")
qualified_name = client.sql_client.make_qualified_table_name("_dlt_version")
assert f"CREATE TABLE IF NOT EXISTS {qualified_name}" in sql