
Commit a69d749

Update tests destination config
steinitzu committed Sep 11, 2024
1 parent 874c871 commit a69d749
Showing 7 changed files with 22 additions and 14 deletions.
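
Every hunk below applies the same pattern: comparisons that previously read destination_config.destination now read destination_config.destination_type. As a minimal sketch of that pattern — the dataclass shape here is hypothetical and only the attribute names themselves appear in the diff:

# Hypothetical sketch: `destination_type` carries the canonical destination name,
# while `destination` may hold a richer object rather than a plain string (assumption).
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class DestinationTestConfiguration:
    destination: Any = None  # assumption: no longer guaranteed to be a name string
    destination_type: str = ""  # e.g. "athena", "postgres", "filesystem", "bigquery"
    table_format: Optional[str] = None
    file_format: Optional[str] = None


# the comparison pattern used throughout the updated tests
cfg = DestinationTestConfiguration(destination_type="postgres", file_format="parquet")
cfg.file_format = cfg.file_format if cfg.destination_type != "postgres" else "csv"
assert cfg.file_format == "csv"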
4 changes: 3 additions & 1 deletion tests/load/pipeline/test_arrow_loading.py
@@ -88,7 +88,9 @@ def some_data():
 
     # use csv for postgres to get native arrow processing
     destination_config.file_format = (
-        destination_config.file_format if destination_config.destination_type != "postgres" else "csv"
+        destination_config.file_format
+        if destination_config.destination_type != "postgres"
+        else "csv"
     )
 
     load_info = pipeline.run(some_data(), **destination_config.run_kwargs)
5 changes: 4 additions & 1 deletion tests/load/pipeline/test_merge_disposition.py
@@ -496,7 +496,10 @@ def test_pipeline_load_parquet(destination_config: DestinationTestConfiguration)
     assert_load_info(info)
     # make sure it was parquet or sql inserts
     files = p.get_load_package_info(p.list_completed_load_packages()[1]).jobs["completed_jobs"]
-    if destination_config.destination == "athena" and destination_config.table_format == "iceberg":
+    if (
+        destination_config.destination_type == "athena"
+        and destination_config.table_format == "iceberg"
+    ):
         # iceberg uses sql to copy tables
         expected_formats.append("sql")
     assert all(f.job_file_info.file_format in expected_formats for f in files)
2 changes: 1 addition & 1 deletion tests/load/pipeline/test_pipelines.py
@@ -561,7 +561,7 @@ def some_source():
     if destination_config.supports_merge:
         expected_completed_jobs += 1
     # add iceberg copy jobs
-    if destination_config.destination == "athena":
+    if destination_config.destination_type == "athena":
         expected_completed_jobs += 2  # if destination_config.supports_merge else 4
     assert len(package_info.jobs["completed_jobs"]) == expected_completed_jobs
 
9 changes: 6 additions & 3 deletions tests/load/pipeline/test_stage_loading.py
@@ -231,7 +231,7 @@ def test_truncate_staging_dataset(destination_config: DestinationTestConfigurati
     with staging_client:
         # except Athena + Iceberg which does not store tables in staging dataset
         if (
-            destination_config.destination == "athena"
+            destination_config.destination_type == "athena"
             and destination_config.table_format == "iceberg"
         ):
             table_count = 0
@@ -257,7 +257,7 @@ def test_truncate_staging_dataset(destination_config: DestinationTestConfigurati
     _, staging_client = pipeline._get_destination_clients(pipeline.default_schema)
     with staging_client:
         # except for Athena which does not delete staging destination tables
-        if destination_config.destination == "athena":
+        if destination_config.destination_type == "athena":
             if destination_config.table_format == "iceberg":
                 table_count = 0
             else:
@@ -302,7 +302,10 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non
     ):
         # Redshift can't load fixed width binary columns from parquet
         exclude_columns.append("col7_precision")
-    if destination_config.destination_type == "databricks" and destination_config.file_format == "jsonl":
+    if (
+        destination_config.destination_type == "databricks"
+        and destination_config.file_format == "jsonl"
+    ):
         exclude_types.extend(["decimal", "binary", "wei", "json", "date"])
         exclude_columns.append("col1_precision")
 
4 changes: 2 additions & 2 deletions tests/load/sources/filesystem/test_filesystem_source.py
@@ -126,7 +126,7 @@ def test_csv_transformers(
 
     # print(pipeline.last_trace.last_normalize_info)
     # must contain 24 rows of A881
-    if not destination_config.destination == "filesystem":
+    if not destination_config.destination_type == "filesystem":
         # TODO: comment out when filesystem destination supports queries (data pond PR)
         assert_query_data(pipeline, "SELECT code FROM met_csv", ["A881"] * 24)
 
@@ -138,7 +138,7 @@ def test_csv_transformers(
     assert_load_info(load_info)
     # print(pipeline.last_trace.last_normalize_info)
     # must contain 48 rows of A803
-    if not destination_config.destination == "filesystem":
+    if not destination_config.destination_type == "filesystem":
         # TODO: comment out when filesystem destination supports queries (data pond PR)
         assert_query_data(pipeline, "SELECT code FROM met_csv", ["A803"] * 48)
     # and 48 rows in total -> A881 got replaced
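
In the two hunks above, assert_query_data is only reached for SQL-capable destinations; the filesystem destination is skipped until it supports queries. As a rough illustration, a hedged sketch of what a helper like this presumably does — run a single-column query against the destination and compare it to the expected values. This is an assumption about the test utility, not its actual implementation:

# Hedged sketch of an assert_query_data-style helper (assumed behavior only).
from typing import Any, Sequence


def assert_query_data_sketch(pipeline: Any, sql: str, expected: Sequence[Any]) -> None:
    # open the pipeline's SQL client, run the query, compare the first column
    with pipeline.sql_client() as client:
        rows = client.execute_sql(sql)
    assert [row[0] for row in rows] == list(expected)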
@@ -51,10 +51,10 @@ def test_load_sql_schema_loads_all_tables(
         schema=sql_source_db.schema,
         backend=backend,
         reflection_level="minimal",
-        type_adapter_callback=default_test_callback(destination_config.destination, backend),
+        type_adapter_callback=default_test_callback(destination_config.destination_type, backend),
     )
 
-    if destination_config.destination == "bigquery" and backend == "connectorx":
+    if destination_config.destination_type == "bigquery" and backend == "connectorx":
         # connectorx generates nanoseconds time which bigquery cannot load
         source.has_precision.add_map(convert_time_to_us)
         source.has_precision_nullable.add_map(convert_time_to_us)
@@ -91,10 +91,10 @@ def test_load_sql_schema_loads_all_tables_parallel(
         schema=sql_source_db.schema,
         backend=backend,
         reflection_level="minimal",
-        type_adapter_callback=default_test_callback(destination_config.destination, backend),
+        type_adapter_callback=default_test_callback(destination_config.destination_type, backend),
     ).parallelize()
 
-    if destination_config.destination == "bigquery" and backend == "connectorx":
+    if destination_config.destination_type == "bigquery" and backend == "connectorx":
         # connectorx generates nanoseconds time which bigquery cannot load
         source.has_precision.add_map(convert_time_to_us)
         source.has_precision_nullable.add_map(convert_time_to_us)
4 changes: 2 additions & 2 deletions tests/load/test_job_client.py
@@ -592,9 +592,9 @@ def test_load_with_all_types(
     client.update_stored_schema()
 
     if isinstance(client, WithStagingDataset):
-        should_load_to_staging = client.should_load_data_to_staging_dataset(table_name)  # type: ignore[attr-defined]
+        should_load_to_staging = client.should_load_data_to_staging_dataset(table_name)
         if should_load_to_staging:
-            with client.with_staging_dataset():  # type: ignore[attr-defined]
+            with client.with_staging_dataset():
                 # create staging for merge dataset
                 client.initialize_storage()
                 client.update_stored_schema()
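
This last hunk drops two `# type: ignore[attr-defined]` comments, presumably because the `isinstance(client, WithStagingDataset)` guard already narrows the type, so the attribute accesses type-check without suppression. A minimal, self-contained sketch of that narrowing pattern; the classes below are hypothetical stand-ins, not dlt's actual interfaces:

# Illustrative isinstance narrowing; class names are hypothetical.
from contextlib import contextmanager
from typing import Iterator


class JobClient:
    def initialize_storage(self) -> None:
        print("initialize storage")


class WithStagingDataset(JobClient):
    def should_load_data_to_staging_dataset(self, table_name: str) -> bool:
        return True

    @contextmanager
    def with_staging_dataset(self) -> Iterator[None]:
        print("switched to staging dataset")
        yield


def load(client: JobClient, table_name: str) -> None:
    if isinstance(client, WithStagingDataset):
        # inside this branch type checkers treat `client` as WithStagingDataset,
        # so no `# type: ignore[attr-defined]` is needed on these calls
        if client.should_load_data_to_staging_dataset(table_name):
            with client.with_staging_dataset():
                client.initialize_storage()


load(WithStagingDataset(), "my_table")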
