diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py
index 6ee34bb758..99756b6fef 100644
--- a/tests/pipeline/test_pipeline.py
+++ b/tests/pipeline/test_pipeline.py
@@ -2878,11 +2878,10 @@ def test_push_table_with_upfront_schema() -> None:
     info = copy_pipeline.run(data, table_name="events", schema=copy_schema)
     assert copy_pipeline.default_schema.version_hash != infer_hash
 
-def test_nested_inserts_correct_target() -> None:
-    @dlt.resource(
-        primary_key="id",
-        columns={"id": {"data_type": "bigint"}},
-    )
+
+@pytest.mark.parametrize("mark_main_resource", (True, False))
+def test_nested_inserts_correct_target(mark_main_resource: bool) -> None:
+    @dlt.resource()
     def my_resource():
         yield [
             {
@@ -2891,7 +2890,7 @@ def my_resource():
                     {"id": "a", "value": 1},
                     {"id": "b", "value": 2},
                     {"id": "c", "value": 3},
-                ]
+                ],
             },
             {
                 "id": 2000,
@@ -2899,30 +2898,32 @@ def my_resource():
                     {"id": "a", "value": 4},
                     {"id": "b", "value": 5},
                     {"id": "c", "value": 6},
-                ]
+                ],
             },
         ]
 
 
-    @dlt.transformer(
-        data_from=my_resource,
-        write_disposition="replace",
-        # parallelized=True,
-        primary_key="id",
-        merge_key="id"
-    )
+    @dlt.transformer(data_from=my_resource)
     def things(
         my_resources: List[TDataItem],
     ) -> Iterable[TDataItem]:
 
         for my_resource in my_resources:
-            fields: List[Dict] = my_resource.pop("fields")
-            yield my_resource
+            fields: List[TDataItem] = my_resource.pop("fields")
+            if mark_main_resource:
+                yield dlt.mark.with_hints(
+                    item=my_resource,
+                    hints=dlt.mark.make_hints(
+                        table_name="things",
+                    )
+                )
+            else:
+                yield my_resource
+
             for field in fields:
-                #id = field.pop("id")
-                id = field["id"]
-                table_name = f"things_{id}"
-                field = { "my_resource_id": my_resource["id"] } | field
+                my_id = field["id"]
+                table_name = f"things_{my_id}"
+                field = {"my_resource_id": my_resource["id"]} | field
                 yield dlt.mark.with_hints(
                     item=field,
                     hints=dlt.mark.make_hints(
@@ -2932,11 +2933,8 @@ def things(
                 )
 
     @dlt.source()
-    def my_source(
-    ) -> Sequence[DltResource]:
-        return (
-            things
-        )
+    def my_source() -> Sequence[DltResource]:
+        return [things]
 
     pipeline_name = "pipe_" + uniq_id()
     pipeline = dlt.pipeline(pipeline_name=pipeline_name, destination="duckdb")
@@ -2944,11 +2942,4 @@ def my_source(
     assert_load_info(info)
     rows = load_tables_to_dicts(pipeline, "things_c", exclude_system_cols=True)
     print(rows)
-    assert_data_table_counts(pipeline, {"things": 1, "things_a": 1, "things_b": 1, "things_c": 1 })
-    assert pipeline.last_trace.last_normalize_info.row_counts == {
-        "_dlt_pipeline_state": 1,
-        "things": 2,
-        "things_a": 2,
-        "things_b": 2,
-        "things_c": 2,
-    }
+    assert_data_table_counts(pipeline, {"things": 2, "things_a": 2, "things_b": 2, "things_c": 2})