Skip to content

Commit

Permalink
fixes tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolfix committed Sep 1, 2024
1 parent 4cd3dfa commit 9111fa2
Show file tree
Hide file tree
Showing 13 changed files with 49 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_destination_athena.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ env:
RUNTIME__DLTHUB_TELEMETRY_ENDPOINT: ${{ secrets.RUNTIME__DLTHUB_TELEMETRY_ENDPOINT }}
ACTIVE_DESTINATIONS: "[\"athena\"]"
ALL_FILESYSTEM_DRIVERS: "[\"memory\"]"
EXCLUDED_DESTINATION_CONFIGURATIONS: "[\"athena-parquet-staging-iceberg\", \"athena-parquet-no-staging-iceberg\"]"
EXCLUDED_DESTINATION_CONFIGURATIONS: "[\"athena-parquet-iceberg-no-staging-iceberg\", \"athena-parquet-iceberg-staging-iceberg\"]"

jobs:
get_docs_changes:
Expand Down
13 changes: 7 additions & 6 deletions dlt/common/destination/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,21 +461,22 @@ def drop_storage(self) -> None:

def verify_schema(self, only_tables: Iterable[str] = None) -> List[PreparedTableSchema]:
"""Verifies schema before loading, returns a list of verified loaded tables."""
load_tables = [
self.prepare_load_table(table_name)
for table_name in (only_tables or self.schema.data_table_names(seen_data_only=True))
]
if exceptions := verify_schema_capabilities(
self.schema,
load_tables,
self.capabilities,
self.config.destination_type,
warnings=False,
):
for exception in exceptions:
logger.error(str(exception))
raise exceptions[0]
return load_tables
# verify all tables with data
return [
self.prepare_load_table(table_name)
for table_name in set(
list(only_tables) + self.schema.data_table_names(seen_data_only=True)
)
]

def update_stored_schema(
self,
Expand Down
3 changes: 1 addition & 2 deletions dlt/common/destination/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

def verify_schema_capabilities(
schema: Schema,
load_tables: Sequence[TTableSchema],
capabilities: DestinationCapabilitiesContext,
destination_type: str,
warnings: bool = True,
Expand Down Expand Up @@ -56,7 +55,7 @@ def verify_schema_capabilities(
)

# check for any table clashes
for table in load_tables:
for table in schema.data_tables():
table_name = table["name"]
# detect table name conflict
cased_table_name = case_identifier(table_name)
Expand Down
6 changes: 3 additions & 3 deletions dlt/destinations/job_client_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,14 +659,14 @@ def _commit_schema_update(self, schema: Schema, schema_str: str) -> None:
)

def verify_schema(self, only_tables: Iterable[str] = None) -> List[PreparedTableSchema]:
loaded_table = super().verify_schema(only_tables)
loaded_tables = super().verify_schema(only_tables)
if exceptions := verify_schema_merge_disposition(
self.schema, loaded_table, self.capabilities, warnings=True
self.schema, loaded_tables, self.capabilities, warnings=True
):
for exception in exceptions:
logger.error(str(exception))
raise exceptions[0]
return loaded_table
return loaded_tables

def prepare_load_job_execution(self, job: RunnableLoadJob) -> None:
self._set_query_tags_for_job(load_id=job._load_id, table=job._load_table)
Expand Down
2 changes: 1 addition & 1 deletion tests/common/cases/schemas/eth/ethereum_schema_v9.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
version: 17
version_hash: oHfYGTI2GHOxuzwVz6+yvMilXUvHYhxrxkanC2T6MAI=
version_hash: PgEHvn5+BHV1jNzNYpx9aDpq6Pq1PSSetufj/h0hKg4=
engine_version: 9
name: ethereum
tables:
Expand Down
2 changes: 1 addition & 1 deletion tests/load/pipeline/test_dremio.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def items() -> Iterator[Any]:
"sub_items": [{"id": 101, "name": "sub item 101"}, {"id": 101, "name": "sub item 102"}],
}

print(pipeline.run([items]), **destination_config.run_kwargs)
print(pipeline.run([items], **destination_config.run_kwargs))

table_counts = load_table_counts(
pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values()]
Expand Down
7 changes: 6 additions & 1 deletion tests/load/pipeline/test_merge_disposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,12 @@ def data():
blocks__transactions = schema_.tables["blocks__transactions"]
blocks__transactions["write_disposition"] = "merge"
blocks__transactions["x-merge-strategy"] = merge_strategy # type: ignore[typeddict-unknown-key]
blocks__transactions["table_format"] = destination_config.table_format

blocks__transactions__logs = schema_.tables["blocks__transactions__logs"]
blocks__transactions__logs["write_disposition"] = "merge"
blocks__transactions__logs["x-merge-strategy"] = merge_strategy # type: ignore[typeddict-unknown-key]
blocks__transactions__logs["table_format"] = destination_config.table_format

return data

Expand Down Expand Up @@ -1179,8 +1181,11 @@ def r():
yield {"foo": "bar"}

assert "scd2" not in p.destination.capabilities().supported_merge_strategies
with pytest.raises(DestinationCapabilitiesException):
with pytest.raises(PipelineStepFailed) as pip_ex:
p.run(r())
assert pip_ex.value.step == "normalize" # failed already in normalize when generating row ids
# PipelineStepFailed -> NormalizeJobFailed -> DestinationCapabilitiesException
assert isinstance(pip_ex.value.__cause__.__cause__, DestinationCapabilitiesException)


@pytest.mark.parametrize(
Expand Down
6 changes: 3 additions & 3 deletions tests/load/pipeline/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def test_default_pipeline_names(
assert p.pipeline_name in possible_names
assert p.pipelines_dir == os.path.abspath(os.path.join(TEST_STORAGE_ROOT, ".dlt", "pipelines"))
assert p.dataset_name in possible_dataset_names
# assert p.destination is None
# assert p.default_schema_name is None
assert p.destination is None
assert p.default_schema_name is None

data = ["a", "b", "c"]
with pytest.raises(PipelineStepFailed) as step_ex:
Expand All @@ -90,7 +90,7 @@ def data_fun() -> Iterator[Any]:
assert p.default_schema_name in ["dlt_pytest", "dlt", "dlt_jb_pytest_runner"]

# this will create additional schema
p.extract(data_fun(), schema=dlt.Schema("names"))
p.extract(data_fun(), schema=dlt.Schema("names"), table_format=destination_config.table_format)
assert p.default_schema_name in ["dlt_pytest", "dlt", "dlt_jb_pytest_runner"]
assert "names" in p.schemas.keys()

Expand Down
18 changes: 10 additions & 8 deletions tests/load/pipeline/test_scd2.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,15 +557,17 @@ def test_validity_column_name_conflict(destination_config: DestinationTestConfig
def r(data):
yield data

# configuring a validity column name that appears in the data should cause an exception
dim_snap = {"nk": 1, "foo": 1, "from": 1} # conflict on "from" column
with pytest.raises(PipelineStepFailed) as pip_ex:
p.run(r(dim_snap), **destination_config.run_kwargs)
assert isinstance(pip_ex.value.__context__.__context__, ColumnNameConflictException)
# the schema check against incoming items was dropped because it was very costly and ran on each row
dim_snap = {"nk": 1, "foo": 1, "from": "X"} # conflict on "from" column
p.run(r(dim_snap), **destination_config.run_kwargs)
dim_snap = {"nk": 1, "foo": 1, "to": 1} # conflict on "to" column
with pytest.raises(PipelineStepFailed):
p.run(r(dim_snap), **destination_config.run_kwargs)
assert isinstance(pip_ex.value.__context__.__context__, ColumnNameConflictException)
p.run(r(dim_snap), **destination_config.run_kwargs)

# instead the variant columns got generated
dim_test_table = p.default_schema.tables["dim_test"]
assert "from__v_text" in dim_test_table["columns"]

# but the `to` column was coerced and then overwritten; this is the cost of dropping the check


@pytest.mark.parametrize(
Expand Down
6 changes: 4 additions & 2 deletions tests/load/pipeline/test_stage_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,12 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None:
assert len(package_info.jobs["failed_jobs"]) == 0
# we have 4 parquet and 4 reference jobs plus one merge job
num_jobs = 4 + 4
# sql job is used to copy parquet to Athena Iceberg table (_dlt_pipeline_state)
num_sql_jobs = 1
num_sql_jobs = 0
if destination_config.supports_merge:
num_sql_jobs += 1
# sql job is used to copy parquet to Athena Iceberg table (_dlt_pipeline_state)
if destination_config.destination == "athena":
num_sql_jobs += 1
assert len(package_info.jobs["completed_jobs"]) == num_jobs + num_sql_jobs
assert (
len(
Expand Down
10 changes: 8 additions & 2 deletions tests/load/test_dummy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,8 +848,14 @@ def test_init_client_truncate_tables() -> None:
"event_bot",
}

replace_ = lambda table: table["write_disposition"] == "replace"
merge_ = lambda table: table["write_disposition"] == "merge"
replace_ = (
lambda table_name: client.prepare_load_table(table_name)["write_disposition"]
== "replace"
)
merge_ = (
lambda table_name: client.prepare_load_table(table_name)["write_disposition"]
== "merge"
)

# set event_bot chain to merge
bot_chain = get_nested_tables(schema.tables, "event_bot")
Expand Down
2 changes: 2 additions & 0 deletions tests/load/weaviate/test_weaviate_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def test_case_sensitive_properties_create(client: WeaviateClient) -> None:
)
client.schema._bump_version()
with pytest.raises(SchemaIdentifierNormalizationCollision) as clash_ex:
client.verify_schema()
client.update_stored_schema()
assert clash_ex.value.identifier_type == "column"
assert clash_ex.value.identifier_name == "coL1"
Expand Down Expand Up @@ -170,6 +171,7 @@ def test_case_sensitive_properties_add(client: WeaviateClient) -> None:
)
client.schema._bump_version()
with pytest.raises(SchemaIdentifierNormalizationCollision):
client.verify_schema()
client.update_stored_schema()

# _, table_columns = client.get_storage_table("ColClass")
Expand Down
3 changes: 2 additions & 1 deletion tests/pipeline/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import dlt
from dlt.common import json, sleep
from dlt.common.configuration.utils import auto_cast
from dlt.common.destination.exceptions import DestinationUndefinedEntity
from dlt.common.pipeline import LoadInfo
from dlt.common.schema.utils import get_table_format
Expand Down Expand Up @@ -147,7 +148,7 @@ def _load_file(client: FSClientBase, filepath) -> List[Dict[str, Any]]:
cols = lines[0][15:-2].split(",")
for line in lines[2:]:
if line:
values = line[1:-3].split(",")
values = map(auto_cast, line[1:-3].split(","))
result.append(dict(zip(cols, values)))

# load parquet
Expand Down

0 comments on commit 9111fa2

Please sign in to comment.