From 2e671194a4352f0877ca5a0fd77e63ad3c6ac8a8 Mon Sep 17 00:00:00 2001
From: Willi
Date: Mon, 22 Jul 2024 18:15:46 +0530
Subject: [PATCH] test that add_map can be used to transform items before the
 incremental function is called

---
 tests/extract/test_incremental.py | 62 +++++++++++++++++++++++++++++++
 1 file changed, 62 insertions(+)

diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py
index fd7878e5f3..3e408a1eb7 100644
--- a/tests/extract/test_incremental.py
+++ b/tests/extract/test_incremental.py
@@ -43,6 +43,8 @@
     ALL_TEST_DATA_ITEM_FORMATS,
 )
 
+import pyarrow as pa
+
 
 @pytest.fixture(autouse=True)
 def switch_to_fifo():
@@ -829,6 +831,66 @@ def some_data(created_at=dlt.sources.incremental("created_at", on_cursor_value_n
     assert data_item_length(values) == 3
 
 
+@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS)
+def test_set_default_value_for_incremental_cursor(item_type: TestDataItemFormat) -> None:
+    data = [
+        {"id": 1, "created_at": 1, "updated_at": 1},
+        {"id": 2, "created_at": 4, "updated_at": None},
+        {"id": 3, "created_at": 3, "updated_at": 3},
+    ]
+    source_items = data_to_item_format(item_type, data)
+
+    @dlt.resource
+    def some_data(created_at=dlt.sources.incremental("updated_at")):
+        yield source_items
+
+    def set_default_updated_at(record):
+        if record.get("updated_at", None) is None:
+            record["updated_at"] = record.get("created_at", pendulum.now().int_timestamp)
+        return record
+
+    def set_default_updated_at_pandas(df):
+        df["updated_at"] = df["updated_at"].fillna(df["created_at"])
+        return df
+
+    def set_default_updated_at_arrow(records):
+        updated_at_is_null = pa.compute.is_null(records.column("updated_at"))
+        updated_at_filled = pa.compute.if_else(
+            updated_at_is_null, records.column("created_at"), records.column("updated_at")
+        )
+        if item_type == "arrow-table":
+            records = records.set_column(
+                records.schema.get_field_index("updated_at"),
+                pa.field("updated_at", records.column("updated_at").type),
+                updated_at_filled,
+            )
+        elif item_type == "arrow-batch":
+            columns = [records.column(i) for i in range(records.num_columns)]
+            columns[2] = updated_at_filled
+            records = pa.RecordBatch.from_arrays(columns, schema=records.schema)
+        return records
+
+    if item_type == "object":
+        func = set_default_updated_at
+    elif item_type == "pandas":
+        func = set_default_updated_at_pandas
+    elif item_type in ["arrow-table", "arrow-batch"]:
+        func = set_default_updated_at_arrow
+
+    result = list(some_data().add_map(func, insert_at=1))
+    values = data_item_to_list(item_type, result)
+    assert data_item_length(values) == 3
+    assert values[1]["updated_at"] == 4
+
+    # same for pipeline run
+    p = dlt.pipeline(pipeline_name=uniq_id())
+    p.extract(some_data().add_map(func, insert_at=1))
+    s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][
+        "updated_at"
+    ]
+    assert s["last_value"] == 4
+
+
 def test_json_path_cursor() -> None:
     @dlt.resource
     def some_data(last_timestamp=dlt.sources.incremental("item.timestamp|modifiedAt")):
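
Editor's note, not part of the patch: the test above pins down the pattern this commit is about — using add_map(func, insert_at=1) to repair a cursor column before dlt's incremental transform sees it. Below is a minimal standalone sketch of the same idea in user code, assuming a current dlt install; the events resource, the pipeline name, and the record values are illustrative and not taken from the patch.

import dlt

@dlt.resource
def events(updated_at=dlt.sources.incremental("updated_at")):
    # the second record has a NULL cursor value, which the incremental
    # step would otherwise reject when comparing cursor values
    yield [
        {"id": 1, "created_at": 1, "updated_at": 1},
        {"id": 2, "created_at": 4, "updated_at": None},
    ]

def backfill_updated_at(record):
    # fall back to created_at whenever the cursor column is NULL
    if record.get("updated_at") is None:
        record["updated_at"] = record["created_at"]
    return record

# insert_at=1 places the map between the generator (step 0) and the
# incremental transform, so the cursor only ever sees filled values
pipeline = dlt.pipeline(pipeline_name="events_backfill_demo")
pipeline.extract(events().add_map(backfill_updated_at, insert_at=1))

The design point the test verifies is the step ordering: without insert_at=1 the map would be appended after the incremental step and the NULL value would reach the cursor first, which is why the state assertion checks that last_value ends up as the backfilled 4.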