
Commit 2e67119

test that add_map can be used to transform items before the incremental function is called
willi-mueller committed Jul 22, 2024
1 parent 95ce2cb commit 2e67119
Showing 1 changed file with 62 additions and 0 deletions.
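For context on what the new test exercises: add_map attaches an item transform to a resource, and insert_at sets its position in the resource's internal pipe, so insert_at=1 runs the transform right after the generator (step 0) and before the incremental step. A minimal sketch of the pattern, with illustrative resource and helper names that are not taken from the commit:

import dlt

@dlt.resource
def events(updated_at=dlt.sources.incremental("updated_at")):
    # the second record has no cursor value; without a backfill, the
    # incremental step would reject the None
    yield from [
        {"id": 1, "created_at": 1, "updated_at": 1},
        {"id": 2, "created_at": 4, "updated_at": None},
    ]

def fill_cursor(record):
    # backfill the cursor field before dlt's incremental logic reads it
    if record.get("updated_at") is None:
        record["updated_at"] = record["created_at"]
    return record

# insert_at=1 slots the transform in front of the incremental step
rows = list(events().add_map(fill_cursor, insert_at=1))
assert all(r["updated_at"] is not None for r in rows)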
tests/extract/test_incremental.py
@@ -43,6 +43,8 @@
    ALL_TEST_DATA_ITEM_FORMATS,
)

import pyarrow as pa
import pyarrow.compute  # binds pa.compute so the kernels used below resolve


@pytest.fixture(autouse=True)
def switch_to_fifo():
@@ -829,6 +831,66 @@ def some_data(created_at=dlt.sources.incremental("created_at", on_cursor_value_n
assert data_item_length(values) == 3


@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS)
def test_set_default_value_for_incremental_cursor(item_type: TestDataItemFormat) -> None:
    data = [
        {"id": 1, "created_at": 1, "updated_at": 1},
        {"id": 2, "created_at": 4, "updated_at": None},
        {"id": 3, "created_at": 3, "updated_at": 3},
    ]
    source_items = data_to_item_format(item_type, data)

    @dlt.resource
    def some_data(created_at=dlt.sources.incremental("updated_at")):
        yield source_items

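    # object items: fill a missing updated_at from created_at, one record at a time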
    def set_default_updated_at(record):
        if record.get("updated_at", None) is None:
            record["updated_at"] = record.get("created_at", pendulum.now().int_timestamp)
        return record

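    # pandas: fill missing updated_at values from the created_at column in one vectorized pass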
    def set_default_updated_at_pandas(df):
        df["updated_at"] = df["updated_at"].fillna(df["created_at"])
        return df

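    # Arrow data is immutable: compute a filled column, then swap it into a new Table or RecordBatch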
    def set_default_updated_at_arrow(records):
        updated_at_is_null = pa.compute.is_null(records.column("updated_at"))
        updated_at_filled = pa.compute.if_else(
            updated_at_is_null, records.column("created_at"), records.column("updated_at")
        )
        if item_type == "arrow-table":
            records = records.set_column(
                records.schema.get_field_index("updated_at"),
                pa.field("updated_at", records.column("updated_at").type),
                updated_at_filled,
            )
        elif item_type == "arrow-batch":
            columns = [records.column(i) for i in range(records.num_columns)]
            columns[2] = updated_at_filled
            records = pa.RecordBatch.from_arrays(columns, schema=records.schema)
        return records

if item_type == "object":
func = set_default_updated_at
elif item_type == "pandas":
func = set_default_updated_at_pandas
elif item_type in ["arrow-table", "arrow-batch"]:
func = set_default_updated_at_arrow

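    # insert_at=1 places the transform before the incremental step in the resource pipe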
    result = list(some_data().add_map(func, insert_at=1))
    values = data_item_to_list(item_type, result)
    assert data_item_length(values) == 3
    assert values[1]["updated_at"] == 4

    # same for pipeline run
    p = dlt.pipeline(pipeline_name=uniq_id())
    p.extract(some_data().add_map(func, insert_at=1))
    s = p.state["sources"][p.default_schema_name]["resources"]["some_data"]["incremental"][
        "updated_at"
    ]
    assert s["last_value"] == 4


def test_json_path_cursor() -> None:
    @dlt.resource
    def some_data(last_timestamp=dlt.sources.incremental("item.timestamp|modifiedAt")):
