Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
VioletM committed Aug 26, 2024
1 parent 227c89c commit 63927de
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 21 deletions.
8 changes: 4 additions & 4 deletions dlt/common/destination/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ class DestinationClientDwhWithStagingConfiguration(DestinationClientDwhConfigura

staging_config: Optional[DestinationClientStagingConfiguration] = None
"""configuration of the staging, if present, injected at runtime"""
truncate_table_before_load_on_staging_destination: bool = True
truncate_tables_on_staging_destination_before_load: bool = True
"""If dlt should truncate the tables on staging destination before loading data."""


Expand Down Expand Up @@ -580,8 +580,8 @@ class SupportsStagingDestination:
"""Adds capability to support a staging destination for the load"""

def __init__(self, config: DestinationClientDwhWithStagingConfiguration) -> None:
self.truncate_table_before_load_on_staging_destination = (
config.truncate_table_before_load_on_staging_destination
self.truncate_tables_on_staging_destination_before_load = (
config.truncate_tables_on_staging_destination_before_load
)

def should_load_data_to_staging_dataset_on_staging_destination(
Expand All @@ -591,7 +591,7 @@ def should_load_data_to_staging_dataset_on_staging_destination(

def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool:
# the default is to truncate the tables on the staging destination...
return self.truncate_table_before_load_on_staging_destination
return self.truncate_tables_on_staging_destination_before_load


# TODO: type Destination properly
Expand Down
2 changes: 1 addition & 1 deletion dlt/destinations/impl/athena/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def should_truncate_table_before_load_on_staging_destination(self, table: TTable
if table["write_disposition"] == "replace" and not self._is_iceberg_table(
self.prepare_load_table(table["name"])
):
return self.truncate_table_before_load_on_staging_destination
return self.truncate_tables_on_staging_destination_before_load
return False

def should_load_data_to_staging_dataset_on_staging_destination(
Expand Down
4 changes: 1 addition & 3 deletions dlt/destinations/impl/bigquery/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,7 @@ def __init__(
config.http_timeout,
config.retry_deadline,
)
SupportsStagingDestination.__init__(
self, config.truncate_table_before_load_on_staging_destination
)
SupportsStagingDestination.__init__(self, config)
super().__init__(schema, config, sql_client)
self.config: BigQueryClientConfiguration = config
self.sql_client: BigQuerySqlClient = sql_client # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion dlt/destinations/impl/dummy/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def __init__(
config: DummyClientConfiguration,
capabilities: DestinationCapabilitiesContext,
) -> None:
SupportsStagingDestination.__init__(self, config)
SupportsStagingDestination.__init__(self, config) # type: ignore
super().__init__(schema, config, capabilities)
self.in_staging_context = False
self.config: DummyClientConfiguration = config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import List, Dict, Any, Generator
import dlt


# Define a dlt resource with write disposition to 'merge'
@dlt.resource(name="parent_with_children", write_disposition={"disposition": "merge"})
def data_source() -> Generator[List[Dict[str, Any]], None, None]:
Expand All @@ -44,13 +45,15 @@ def data_source() -> Generator[List[Dict[str, Any]], None, None]:

yield data


# Function to add parent_id to each child record within a parent record
def add_parent_id(record: Dict[str, Any]) -> Dict[str, Any]:
parent_id_key = "parent_id"
for child in record["children"]:
child[parent_id_key] = record[parent_id_key]
return record


if __name__ == "__main__":
# Create and configure the dlt pipeline
pipeline = dlt.pipeline(
Expand All @@ -60,10 +63,6 @@ def add_parent_id(record: Dict[str, Any]) -> Dict[str, Any]:
)

# Run the pipeline
load_info = pipeline.run(
data_source()
.add_map(add_parent_id),
primary_key="parent_id"
)
load_info = pipeline.run(data_source().add_map(add_parent_id), primary_key="parent_id")
# Output the load information after pipeline execution
print(load_info)
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import pytest

from tests.utils import skipifgithubfork
Expand Down Expand Up @@ -29,6 +28,7 @@
from typing import List, Dict, Any, Generator
import dlt


# Define a dlt resource with write disposition to 'merge'
@dlt.resource(name="parent_with_children", write_disposition={"disposition": "merge"})
def data_source() -> Generator[List[Dict[str, Any]], None, None]:
Expand All @@ -51,13 +51,15 @@ def data_source() -> Generator[List[Dict[str, Any]], None, None]:

yield data


# Function to add parent_id to each child record within a parent record
def add_parent_id(record: Dict[str, Any]) -> Dict[str, Any]:
parent_id_key = "parent_id"
for child in record["children"]:
child[parent_id_key] = record[parent_id_key]
return record


@skipifgithubfork
@pytest.mark.forked
def test_parent_child_relationship():
Expand All @@ -69,10 +71,6 @@ def test_parent_child_relationship():
)

# Run the pipeline
load_info = pipeline.run(
data_source()
.add_map(add_parent_id),
primary_key="parent_id"
)
load_info = pipeline.run(data_source().add_map(add_parent_id), primary_key="parent_id")
# Output the load information after pipeline execution
print(load_info)
2 changes: 1 addition & 1 deletion tests/load/test_dummy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ def test_truncate_table_before_load_on_stanging(to_truncate) -> None:
load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES)
destination_client = load.get_destination_client(schema)
assert (
destination_client.should_truncate_table_before_load_on_staging_destination(
destination_client.should_truncate_table_before_load_on_staging_destination( # type: ignore
schema.tables["_dlt_version"]
)
== to_truncate
Expand Down

0 comments on commit 63927de

Please sign in to comment.