From e19b85e32016b6b614023e55801aa276da96afe3 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 6 Mar 2024 22:23:32 +0100 Subject: [PATCH 1/4] rewrites incremental: computation of hashes vastly reduced, fixed wrong criteria when to deduplicate, unique index in arrow frames rarely created --- dlt/extract/incremental/__init__.py | 24 +++- dlt/extract/incremental/transform.py | 194 ++++++++++++++------------- 2 files changed, 118 insertions(+), 100 deletions(-) diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py index 54e8b3d447..bfc9542a1d 100644 --- a/dlt/extract/incremental/__init__.py +++ b/dlt/extract/incremental/__init__.py @@ -163,11 +163,12 @@ def _make_transforms(self) -> None: self._transformers[dt] = kls( self.resource_name, self.cursor_path, + self.initial_value, self.start_value, self.end_value, - self._cached_state, self.last_value_func, self._primary_key, + set(self._cached_state["unique_hashes"]), ) @classmethod @@ -453,14 +454,29 @@ def __call__(self, rows: TDataItems, meta: Any = None) -> Optional[TDataItems]: return rows transformer = self._get_transformer(rows) - if isinstance(rows, list): - return [ + rows = [ item for item in (self._transform_item(transformer, row) for row in rows) if item is not None ] - return self._transform_item(transformer, rows) + else: + rows = self._transform_item(transformer, rows) + + # write back state + self._cached_state["last_value"] = transformer.last_value + if self.primary_key != (): + # compute hashes for new last rows + unique_hashes = set( + transformer.compute_unique_value(row, self.primary_key) + for row in transformer.last_rows + ) + # add directly computed hashes + unique_hashes.update(transformer.unique_hashes) + print(unique_hashes) + self._cached_state["unique_hashes"] = list(unique_hashes) + + return rows Incremental.EMPTY = Incremental[Any]("") diff --git a/dlt/extract/incremental/transform.py b/dlt/extract/incremental/transform.py index e20617cf63..74b0966cee 100644 --- a/dlt/extract/incremental/transform.py +++ b/dlt/extract/incremental/transform.py @@ -1,17 +1,17 @@ from datetime import datetime, date # noqa: I251 -from typing import Any, Optional, Tuple, List +from typing import Any, Optional, Set, Tuple, List from dlt.common.exceptions import MissingDependencyException from dlt.common.utils import digest128 from dlt.common.json import json from dlt.common import pendulum -from dlt.common.typing import TDataItem, TDataItems -from dlt.common.jsonpath import TJsonPath, find_values, JSONPathFields, compile_path +from dlt.common.typing import TDataItem +from dlt.common.jsonpath import find_values, JSONPathFields, compile_path from dlt.extract.incremental.exceptions import ( IncrementalCursorPathMissing, IncrementalPrimaryKeyMissing, ) -from dlt.extract.incremental.typing import IncrementalColumnState, TCursorValue, LastValueFunc +from dlt.extract.incremental.typing import TCursorValue, LastValueFunc from dlt.extract.utils import resolve_column_value from dlt.extract.items import TTableHintTemplate from dlt.common.schema.typing import TColumnNames @@ -34,19 +34,24 @@ def __init__( self, resource_name: str, cursor_path: str, + initial_value: Optional[TCursorValue], start_value: Optional[TCursorValue], end_value: Optional[TCursorValue], - incremental_state: IncrementalColumnState, last_value_func: LastValueFunc[TCursorValue], primary_key: Optional[TTableHintTemplate[TColumnNames]], + unique_hashes: Set[str], ) -> None: self.resource_name = resource_name self.cursor_path = cursor_path + 
self.initial_value = initial_value self.start_value = start_value + self.last_value = start_value self.end_value = end_value - self.incremental_state = incremental_state + self.last_rows: List[TDataItem] = [] self.last_value_func = last_value_func self.primary_key = primary_key + self.unique_hashes = unique_hashes + self.start_unique_hashes = set(unique_hashes) # compile jsonpath self._compiled_cursor_path = compile_path(cursor_path) @@ -59,18 +64,10 @@ def __init__( self.cursor_path = self._compiled_cursor_path.fields[0] self._compiled_cursor_path = None - def __call__( - self, - row: TDataItem, - ) -> Tuple[bool, bool, bool]: ... - - -class JsonIncremental(IncrementalTransform): - def unique_value( + def compute_unique_value( self, row: TDataItem, primary_key: Optional[TTableHintTemplate[TColumnNames]], - resource_name: str, ) -> str: try: if primary_key: @@ -80,8 +77,15 @@ def unique_value( else: return None except KeyError as k_err: - raise IncrementalPrimaryKeyMissing(resource_name, k_err.args[0], row) + raise IncrementalPrimaryKeyMissing(self.resource_name, k_err.args[0], row) + def __call__( + self, + row: TDataItem, + ) -> Tuple[bool, bool, bool]: ... + + +class JsonIncremental(IncrementalTransform): def find_cursor_value(self, row: TDataItem) -> Any: """Finds value in row at cursor defined by self.cursor_path. @@ -113,7 +117,8 @@ def __call__( return row, False, False row_value = self.find_cursor_value(row) - last_value = self.incremental_state["last_value"] + last_value = self.last_value + last_value_func = self.last_value_func # For datetime cursor, ensure the value is a timezone aware datetime. # The object saved in state will always be a tz aware pendulum datetime so this ensures values are comparable @@ -128,41 +133,46 @@ def __call__( # Check whether end_value has been reached # Filter end value ranges exclusively, so in case of "max" function we remove values >= end_value if self.end_value is not None and ( - self.last_value_func((row_value, self.end_value)) != self.end_value - or self.last_value_func((row_value,)) == self.end_value + last_value_func((row_value, self.end_value)) != self.end_value + or last_value_func((row_value,)) == self.end_value ): return None, False, True check_values = (row_value,) + ((last_value,) if last_value is not None else ()) - new_value = self.last_value_func(check_values) + new_value = last_value_func(check_values) + # new_value is "less" or equal to last_value (the actual max) if last_value == new_value: - processed_row_value = self.last_value_func((row_value,)) - # we store row id for all records with the current "last_value" in state and use it to deduplicate - - if processed_row_value == last_value: - unique_value = self.unique_value(row, self.primary_key, self.resource_name) - # if unique value exists then use it to deduplicate - if unique_value: - if unique_value in self.incremental_state["unique_hashes"]: - return None, False, False - # add new hash only if the record row id is same as current last value - self.incremental_state["unique_hashes"].append(unique_value) - return row, False, False - # skip the record that is not a last_value or new_value: that record was already processed + # use func to compute row_value into last_value compatible + processed_row_value = last_value_func((row_value,)) + # skip the record that is not a start_value or new_value: that record was already processed check_values = (row_value,) + ( (self.start_value,) if self.start_value is not None else () ) - new_value = self.last_value_func(check_values) + 
new_value = last_value_func(check_values) # Include rows == start_value but exclude "lower" - if new_value == self.start_value and processed_row_value != self.start_value: - return None, True, False - else: - return row, False, False + # new_value is "less" or equal to start_value (the initial max) + if new_value == self.start_value: + # if equal there's still a chance that item gets in + if processed_row_value == self.start_value: + unique_value = self.compute_unique_value(row, self.primary_key) + # if unique value exists then use it to deduplicate + if unique_value: + if unique_value in self.start_unique_hashes: + return None, True, False + else: + # smaller than start value gets out + return None, True, False + + # we store row id for all records with the current "last_value" in state and use it to deduplicate + if processed_row_value == last_value: + # add new hash only if the record row id is same as current last value + self.last_rows.append(row) else: - self.incremental_state["last_value"] = new_value - unique_value = self.unique_value(row, self.primary_key, self.resource_name) - if unique_value: - self.incremental_state["unique_hashes"] = [unique_value] + self.last_value = new_value + # store rows with "max" values to compute hashes + # only when needed + self.last_rows = [row] + self.unique_hashes = set() return row, False, False @@ -170,19 +180,25 @@ def __call__( class ArrowIncremental(IncrementalTransform): _dlt_index = "_dlt_index" - def unique_values( - self, item: "TAnyArrowItem", unique_columns: List[str], resource_name: str + def compute_unique_values(self, item: "TAnyArrowItem", unique_columns: List[str]) -> List[str]: + if not unique_columns: + return [] + rows = item.select(unique_columns).to_pylist() + return [self.compute_unique_value(row, self.primary_key) for row in rows] + + def compute_unique_values_with_index( + self, item: "TAnyArrowItem", unique_columns: List[str] ) -> List[Tuple[int, str]]: if not unique_columns: return [] - item = item indices = item[self._dlt_index].to_pylist() rows = item.select(unique_columns).to_pylist() return [ - (index, digest128(json.dumps(row, sort_keys=True))) for index, row in zip(indices, rows) + (index, self.compute_unique_value(row, self.primary_key)) + for index, row in zip(indices, rows) ] - def _deduplicate( + def _add_unique_index( self, tbl: "pa.Table", unique_columns: Optional[List[str]], aggregate: str, cursor_path: str ) -> "pa.Table": """Creates unique index if necessary.""" @@ -222,8 +238,6 @@ def __call__( if not tbl: # row is None or empty arrow table return tbl, start_out_of_range, end_out_of_range - last_value = self.incremental_state["last_value"] - if self.last_value_func is max: compute = pa.compute.max aggregate = "max" @@ -267,64 +281,52 @@ def __call__( # NOTE: pyarrow bool *always* evaluates to python True. 
`as_py()` is necessary end_out_of_range = not end_compare(row_value_scalar, end_value_scalar).as_py() - if last_value is not None: - if self.start_value is not None: - # Remove rows lower than the last start value - keep_filter = last_value_compare( - tbl[cursor_path], to_arrow_scalar(self.start_value, cursor_data_type) - ) - start_out_of_range = bool(pa.compute.any(pa.compute.invert(keep_filter)).as_py()) - tbl = tbl.filter(keep_filter) - + if self.start_value is not None: + start_value_scalar = to_arrow_scalar(self.start_value, cursor_data_type) + # Remove rows lower or equal than the last start value + keep_filter = last_value_compare(tbl[cursor_path], start_value_scalar) + start_out_of_range = bool(pa.compute.any(pa.compute.invert(keep_filter)).as_py()) + tbl = tbl.filter(keep_filter) # Deduplicate after filtering old values - last_value_scalar = to_arrow_scalar(last_value, cursor_data_type) - tbl = self._deduplicate(tbl, unique_columns, aggregate, cursor_path) - # Remove already processed rows where the cursor is equal to the last value - eq_rows = tbl.filter(pa.compute.equal(tbl[cursor_path], last_value_scalar)) + tbl = self._add_unique_index(tbl, unique_columns, aggregate, cursor_path) + # Remove already processed rows where the cursor is equal to the start value + eq_rows = tbl.filter(pa.compute.equal(tbl[cursor_path], start_value_scalar)) # compute index, unique hash mapping - unique_values = self.unique_values(eq_rows, unique_columns, self.resource_name) - unique_values = [ + unique_values_index = self.compute_unique_values_with_index(eq_rows, unique_columns) + unique_values_index = [ (i, uq_val) - for i, uq_val in unique_values - if uq_val in self.incremental_state["unique_hashes"] + for i, uq_val in unique_values_index + if uq_val in self.start_unique_hashes ] - remove_idx = pa.array(i for i, _ in unique_values) + # find rows with unique ids that were stored from previous run + remove_idx = pa.array(i for i, _ in unique_values_index) # Filter the table tbl = tbl.filter(pa.compute.invert(pa.compute.is_in(tbl[self._dlt_index], remove_idx))) - if ( - new_value_compare(row_value_scalar, last_value_scalar).as_py() - and row_value != last_value - ): # Last value has changed - self.incremental_state["last_value"] = row_value - # Compute unique hashes for all rows equal to row value - self.incremental_state["unique_hashes"] = [ - uq_val - for _, uq_val in self.unique_values( + if ( + self.last_value is None + or new_value_compare( + row_value_scalar, to_arrow_scalar(self.last_value, cursor_data_type) + ).as_py() + ): # Last value has changed + self.last_value = row_value + # Compute unique hashes for all rows equal to row value + self.unique_hashes = set( + self.compute_unique_values( + tbl.filter(pa.compute.equal(tbl[cursor_path], row_value_scalar)), + unique_columns, + ) + ) + elif self.last_value == row_value: + # last value is unchanged, add the hashes + self.unique_hashes.update( + set( + self.compute_unique_values( tbl.filter(pa.compute.equal(tbl[cursor_path], row_value_scalar)), unique_columns, - self.resource_name, ) - ] - else: - # last value is unchanged, add the hashes - self.incremental_state["unique_hashes"] = list( - set( - self.incremental_state["unique_hashes"] - + [uq_val for _, uq_val in unique_values] - ) - ) - else: - tbl = self._deduplicate(tbl, unique_columns, aggregate, cursor_path) - self.incremental_state["last_value"] = row_value - self.incremental_state["unique_hashes"] = [ - uq_val - for _, uq_val in self.unique_values( - 
tbl.filter(pa.compute.equal(tbl[cursor_path], row_value_scalar)), - unique_columns, - self.resource_name, ) - ] + ) if len(tbl) == 0: return None, start_out_of_range, end_out_of_range From 23198996eeb8bee04b9f4a4d530277eafb6dd514 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 6 Mar 2024 22:24:03 +0100 Subject: [PATCH 2/4] initial tests for ordered, random and overlapping incremental ranges --- tests/extract/test_incremental.py | 42 ++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py index 7956c83947..b3db3b852b 100644 --- a/tests/extract/test_incremental.py +++ b/tests/extract/test_incremental.py @@ -1,5 +1,6 @@ import os import asyncio +import random from time import sleep from typing import Optional, Any from unittest import mock @@ -125,11 +126,11 @@ def test_unique_keys_are_deduplicated(item_type: TDataItemFormat) -> None: {"created_at": 3, "id": "e"}, ] data2 = [ + {"created_at": 4, "id": "g"}, {"created_at": 3, "id": "c"}, {"created_at": 3, "id": "d"}, {"created_at": 3, "id": "e"}, {"created_at": 3, "id": "f"}, - {"created_at": 4, "id": "g"}, ] source_items1 = data_to_item_format(item_type, data1) @@ -1307,7 +1308,6 @@ def descending_single_item( for i in reversed(range(14)): data = [{"updated_at": i}] yield from data_to_item_format(item_type, data) - yield {"updated_at": i} if i >= 10: assert updated_at.start_out_of_range is False else: @@ -1375,7 +1375,8 @@ def descending( assert data_item_length(data) == 48 - 10 + 1 # both bounds included -def test_transformer_row_order_out_of_range() -> None: +@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) +def test_transformer_row_order_out_of_range(item_type: TDataItemFormat) -> None: out_of_range = [] @dlt.transformer @@ -1387,13 +1388,14 @@ def descending( ) -> Any: for chunk in chunks(count(start=48, step=-1), 10): data = [{"updated_at": i, "package": package} for i in chunk] + # print(data) yield data_to_item_format("json", data) if updated_at.can_close(): out_of_range.append(package) return data = list([3, 2, 1] | descending) - assert len(data) == 48 - 10 + 1 + assert data_item_length(data) == 48 - 10 + 1 # we take full package 3 and then nothing in 1 and 2 assert len(out_of_range) == 3 @@ -1453,6 +1455,38 @@ def ascending_desc( assert data_item_length(data) == 45 - 22 +@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) +def test_unique_values_unordered_rows(item_type: TDataItemFormat) -> None: + @dlt.resource + def random_ascending_chunks( + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at", + initial_value=10, + ) + ) -> Any: + range_ = list(range(updated_at.start_value, updated_at.start_value + 121)) + random.shuffle(range_) + for chunk in chunks(range_, 30): + # make sure that overlapping element is the last one + data = [{"updated_at": i} for i in chunk] + # random.shuffle(data) + print(data) + yield data_to_item_format(item_type, data) + + os.environ["COMPLETED_PROB"] = "1.0" # make it complete immediately + pipeline = dlt.pipeline("test_unique_values_unordered_rows", destination="dummy") + # print(list(random_ascending_chunks())) + pipeline.run(random_ascending_chunks()) + assert pipeline.last_trace.last_normalize_info.row_counts["random_ascending_chunks"] == 121 + + # 120 rows (one overlap - incremental reacquires and deduplicates) + # print(list(random_ascending_chunks())) + pipeline.run(random_ascending_chunks()) + assert 
pipeline.last_trace.last_normalize_info.row_counts["random_ascending_chunks"] == 120

    # test next batch adding to unique


@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS)
def test_get_incremental_value_type(item_type: TDataItemFormat) -> None:
    assert dlt.sources.incremental("id").get_incremental_value_type() is Any

From 31c035c1af3c3a0700eea22b7dc82172897c0731 Mon Sep 17 00:00:00 2001
From: Marcin Rudolf
Date: Wed, 6 Mar 2024 22:24:22 +0100
Subject: [PATCH 3/4] clarifies what deduplication in incremental means

---
 .../code/zendesk-snippets.py                  |  2 +-
 .../docs/general-usage/incremental-loading.md | 19 ++++++++++++++-----
 tests/common/test_validation.py               | 15 +++++++++++++++
 3 files changed, 30 insertions(+), 6 deletions(-)

diff --git a/docs/website/docs/examples/incremental_loading/code/zendesk-snippets.py b/docs/website/docs/examples/incremental_loading/code/zendesk-snippets.py
index ff12a00fca..05ea18cb9e 100644
--- a/docs/website/docs/examples/incremental_loading/code/zendesk-snippets.py
+++ b/docs/website/docs/examples/incremental_loading/code/zendesk-snippets.py
@@ -140,4 +140,4 @@ def get_pages(
 
     # check that stuff was loaded
     row_counts = pipeline.last_trace.last_normalize_info.row_counts
-    assert row_counts["ticket_events"] == 17
\ No newline at end of file
+    assert row_counts["ticket_events"] == 17
diff --git a/docs/website/docs/general-usage/incremental-loading.md b/docs/website/docs/general-usage/incremental-loading.md
index dd52c9c750..144b176332 100644
--- a/docs/website/docs/general-usage/incremental-loading.md
+++ b/docs/website/docs/general-usage/incremental-loading.md
@@ -454,11 +454,20 @@ def tickets(
 ```
 :::
 
-### Deduplication primary_key
-
-`dlt.sources.incremental` will inherit the primary key that is set on the resource.
-
- let's you optionally set a `primary_key` that is used exclusively to
+### Deduplicate overlapping ranges with primary key
+
+`Incremental` **does not** deduplicate datasets like the **merge** write disposition does. It does, however,
+make sure that when another portion of data is extracted, records that were previously loaded won't be
+included again. `dlt` assumes that you load a range of data with an inclusive lower bound (i.e. greater than or equal).
+This makes sure that you never lose any data, but it also means that some rows will be re-acquired.
+For example: if you have a database table with a cursor field on `updated_at` that has day resolution, there is a high
+chance that more records will be added after you extract the data for a given day. When you extract on the next day, you
+should reacquire data from the last day to make sure all records are present; this will, however, create an overlap with
+data from the previous extract.
+
+By default, a content hash (a hash of the `json` representation of a row) is used to deduplicate.
+This may be slow, so `dlt.sources.incremental` will inherit the primary key that is set on the resource.
+You can optionally set a `primary_key` that is used exclusively to
 deduplicate and which does not become a table hint. The same setting lets you disable the deduplication
 altogether when empty tuple is passed. Below we pass `primary_key` directly to `incremental` to disable deduplication.
That overrides `delta` primary_key set in the resource: diff --git a/tests/common/test_validation.py b/tests/common/test_validation.py index 3fff3bf2ea..f7773fb89c 100644 --- a/tests/common/test_validation.py +++ b/tests/common/test_validation.py @@ -273,3 +273,18 @@ def f(item: Union[TDataItem, TDynHintType]) -> TDynHintType: validate_dict( TTestRecordCallable, test_item, path=".", validator_f=lambda p, pk, pv, t: callable(pv) ) + + +# def test_union_merge() -> None: +# """Overriding fields is simply illegal in TypedDict""" +# class EndpointResource(TypedDict, total=False): +# name: TTableHintTemplate[str] + +# class TTestRecordNoName(EndpointResource, total=False): +# name: Optional[TTableHintTemplate[str]] + +# # test_item = {"name": None} +# # validate_dict(TTestRecordNoName, test_item, path=".") + +# test_item = {} +# validate_dict(TTestRecordNoName, test_item, path=".") From abc46b6f17cdf937e2a842b2d13ee92b570cfbf9 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Thu, 7 Mar 2024 22:21:00 +0100 Subject: [PATCH 4/4] handles no deduplication case explicitly, more tests --- dlt/extract/incremental/__init__.py | 3 +- dlt/extract/incremental/transform.py | 77 +++++++++------- tests/extract/test_incremental.py | 132 ++++++++++++++++++++++++--- 3 files changed, 162 insertions(+), 50 deletions(-) diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py index bfc9542a1d..9ad174fd63 100644 --- a/dlt/extract/incremental/__init__.py +++ b/dlt/extract/incremental/__init__.py @@ -465,7 +465,7 @@ def __call__(self, rows: TDataItems, meta: Any = None) -> Optional[TDataItems]: # write back state self._cached_state["last_value"] = transformer.last_value - if self.primary_key != (): + if not transformer.deduplication_disabled: # compute hashes for new last rows unique_hashes = set( transformer.compute_unique_value(row, self.primary_key) @@ -473,7 +473,6 @@ def __call__(self, rows: TDataItems, meta: Any = None) -> Optional[TDataItems]: ) # add directly computed hashes unique_hashes.update(transformer.unique_hashes) - print(unique_hashes) self._cached_state["unique_hashes"] = list(unique_hashes) return rows diff --git a/dlt/extract/incremental/transform.py b/dlt/extract/incremental/transform.py index 74b0966cee..2ad827b755 100644 --- a/dlt/extract/incremental/transform.py +++ b/dlt/extract/incremental/transform.py @@ -70,6 +70,11 @@ def compute_unique_value( primary_key: Optional[TTableHintTemplate[TColumnNames]], ) -> str: try: + assert not self.deduplication_disabled, ( + f"{self.resource_name}: Attempt to compute unique values when deduplication is" + " disabled" + ) + if primary_key: return digest128(json.dumps(resolve_column_value(primary_key, row), sort_keys=True)) elif primary_key is None: @@ -84,6 +89,11 @@ def __call__( row: TDataItem, ) -> Tuple[bool, bool, bool]: ... 
+ @property + def deduplication_disabled(self) -> bool: + """Skip deduplication when length of the key is 0""" + return isinstance(self.primary_key, (list, tuple)) and len(self.primary_key) == 0 + class JsonIncremental(IncrementalTransform): def find_cursor_value(self, row: TDataItem) -> Any: @@ -154,13 +164,13 @@ def __call__( if new_value == self.start_value: # if equal there's still a chance that item gets in if processed_row_value == self.start_value: - unique_value = self.compute_unique_value(row, self.primary_key) - # if unique value exists then use it to deduplicate - if unique_value: + if not self.deduplication_disabled: + unique_value = self.compute_unique_value(row, self.primary_key) + # if unique value exists then use it to deduplicate if unique_value in self.start_unique_hashes: return None, True, False else: - # smaller than start value gets out + # smaller than start value: gets out return None, True, False # we store row id for all records with the current "last_value" in state and use it to deduplicate @@ -169,8 +179,7 @@ def __call__( self.last_rows.append(row) else: self.last_value = new_value - # store rows with "max" values to compute hashes - # only when needed + # store rows with "max" values to compute hashes after processing full batch self.last_rows = [row] self.unique_hashes = set() @@ -198,9 +207,7 @@ def compute_unique_values_with_index( for index, row in zip(indices, rows) ] - def _add_unique_index( - self, tbl: "pa.Table", unique_columns: Optional[List[str]], aggregate: str, cursor_path: str - ) -> "pa.Table": + def _add_unique_index(self, tbl: "pa.Table") -> "pa.Table": """Creates unique index if necessary.""" # create unique index if necessary if self._dlt_index not in tbl.schema.names: @@ -231,8 +238,6 @@ def __call__( self._dlt_index = primary_key elif primary_key is None: unique_columns = tbl.schema.names - else: # deduplicating is disabled - unique_columns = None start_out_of_range = end_out_of_range = False if not tbl: # row is None or empty arrow table @@ -240,13 +245,11 @@ def __call__( if self.last_value_func is max: compute = pa.compute.max - aggregate = "max" end_compare = pa.compute.less last_value_compare = pa.compute.greater_equal new_value_compare = pa.compute.greater elif self.last_value_func is min: compute = pa.compute.min - aggregate = "min" end_compare = pa.compute.greater last_value_compare = pa.compute.less_equal new_value_compare = pa.compute.less @@ -287,21 +290,24 @@ def __call__( keep_filter = last_value_compare(tbl[cursor_path], start_value_scalar) start_out_of_range = bool(pa.compute.any(pa.compute.invert(keep_filter)).as_py()) tbl = tbl.filter(keep_filter) - # Deduplicate after filtering old values - tbl = self._add_unique_index(tbl, unique_columns, aggregate, cursor_path) - # Remove already processed rows where the cursor is equal to the start value - eq_rows = tbl.filter(pa.compute.equal(tbl[cursor_path], start_value_scalar)) - # compute index, unique hash mapping - unique_values_index = self.compute_unique_values_with_index(eq_rows, unique_columns) - unique_values_index = [ - (i, uq_val) - for i, uq_val in unique_values_index - if uq_val in self.start_unique_hashes - ] - # find rows with unique ids that were stored from previous run - remove_idx = pa.array(i for i, _ in unique_values_index) - # Filter the table - tbl = tbl.filter(pa.compute.invert(pa.compute.is_in(tbl[self._dlt_index], remove_idx))) + if not self.deduplication_disabled: + # Deduplicate after filtering old values + tbl = self._add_unique_index(tbl) + # Remove 
already processed rows where the cursor is equal to the start value + eq_rows = tbl.filter(pa.compute.equal(tbl[cursor_path], start_value_scalar)) + # compute index, unique hash mapping + unique_values_index = self.compute_unique_values_with_index(eq_rows, unique_columns) + unique_values_index = [ + (i, uq_val) + for i, uq_val in unique_values_index + if uq_val in self.start_unique_hashes + ] + # find rows with unique ids that were stored from previous run + remove_idx = pa.array(i for i, _ in unique_values_index) + # Filter the table + tbl = tbl.filter( + pa.compute.invert(pa.compute.is_in(tbl[self._dlt_index], remove_idx)) + ) if ( self.last_value is None @@ -310,14 +316,15 @@ def __call__( ).as_py() ): # Last value has changed self.last_value = row_value - # Compute unique hashes for all rows equal to row value - self.unique_hashes = set( - self.compute_unique_values( - tbl.filter(pa.compute.equal(tbl[cursor_path], row_value_scalar)), - unique_columns, + if not self.deduplication_disabled: + # Compute unique hashes for all rows equal to row value + self.unique_hashes = set( + self.compute_unique_values( + tbl.filter(pa.compute.equal(tbl[cursor_path], row_value_scalar)), + unique_columns, + ) ) - ) - elif self.last_value == row_value: + elif self.last_value == row_value and not self.deduplication_disabled: # last value is unchanged, add the hashes self.unique_hashes.update( set( diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py index b3db3b852b..a393706de7 100644 --- a/tests/extract/test_incremental.py +++ b/tests/extract/test_incremental.py @@ -15,13 +15,14 @@ from dlt.common.configuration.specs.base_configuration import configspec, BaseConfiguration from dlt.common.configuration import ConfigurationValueError from dlt.common.pendulum import pendulum, timedelta -from dlt.common.pipeline import StateInjectableContext, resource_state +from dlt.common.pipeline import NormalizeInfo, StateInjectableContext, resource_state from dlt.common.schema.schema import Schema from dlt.common.utils import uniq_id, digest128, chunks from dlt.common.json import json from dlt.extract import DltSource from dlt.extract.exceptions import InvalidStepFunctionArguments +from dlt.extract.resource import DltResource from dlt.sources.helpers.transform import take_first from dlt.extract.incremental.exceptions import ( IncrementalCursorPathMissing, @@ -1456,35 +1457,140 @@ def ascending_desc( @pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_unique_values_unordered_rows(item_type: TDataItemFormat) -> None: - @dlt.resource +@pytest.mark.parametrize("order", ["random", "desc", "asc"]) +@pytest.mark.parametrize("primary_key", [[], None, "updated_at"]) +@pytest.mark.parametrize( + "deterministic", (True, False), ids=("deterministic-record", "non-deterministic-record") +) +def test_unique_values_unordered_rows( + item_type: TDataItemFormat, order: str, primary_key: Any, deterministic: bool +) -> None: + @dlt.resource(primary_key=primary_key) def random_ascending_chunks( + order: str, updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( "updated_at", initial_value=10, - ) + ), ) -> Any: range_ = list(range(updated_at.start_value, updated_at.start_value + 121)) - random.shuffle(range_) + if order == "random": + random.shuffle(range_) + if order == "desc": + range_ = reversed(range_) # type: ignore[assignment] + for chunk in chunks(range_, 30): # make sure that overlapping element is the last one - data = [{"updated_at": i} for i in chunk] + data = [ + 
{"updated_at": i, "rand": random.random() if not deterministic else 0} + for i in chunk + ] # random.shuffle(data) - print(data) yield data_to_item_format(item_type, data) os.environ["COMPLETED_PROB"] = "1.0" # make it complete immediately pipeline = dlt.pipeline("test_unique_values_unordered_rows", destination="dummy") - # print(list(random_ascending_chunks())) - pipeline.run(random_ascending_chunks()) + pipeline.run(random_ascending_chunks(order)) assert pipeline.last_trace.last_normalize_info.row_counts["random_ascending_chunks"] == 121 # 120 rows (one overlap - incremental reacquires and deduplicates) - # print(list(random_ascending_chunks())) - pipeline.run(random_ascending_chunks()) - assert pipeline.last_trace.last_normalize_info.row_counts["random_ascending_chunks"] == 120 + pipeline.run(random_ascending_chunks(order)) + # overlapping element must be deduped when: + # 1. we have primary key on just updated at + # OR we have a key on full record but the record is deterministic so duplicate may be found + rows = 120 if primary_key == "updated_at" or (deterministic and primary_key != []) else 121 + assert pipeline.last_trace.last_normalize_info.row_counts["random_ascending_chunks"] == rows + - # test next batch adding to unique +@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) +@pytest.mark.parametrize("primary_key", [[], None, "updated_at"]) # [], None, +@pytest.mark.parametrize( + "deterministic", (True, False), ids=("deterministic-record", "non-deterministic-record") +) +def test_carry_unique_hashes( + item_type: TDataItemFormat, primary_key: Any, deterministic: bool +) -> None: + # each day extends list of hashes and removes duplicates until the last day + + @dlt.resource(primary_key=primary_key) + def random_ascending_chunks( + # order: str, + day: int, + updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( + "updated_at", + initial_value=10, + ), + ) -> Any: + range_ = random.sample( + range(updated_at.initial_value, updated_at.initial_value + 10), k=10 + ) # list(range(updated_at.initial_value, updated_at.initial_value + 10)) + range_ += [100] + if day == 4: + # on day 4 add an element that will reset all others + range_ += [1000] + + for chunk in chunks(range_, 3): + # make sure that overlapping element is the last one + data = [ + {"updated_at": i, "rand": random.random() if not deterministic else 0} + for i in chunk + ] + yield data_to_item_format(item_type, data) + + os.environ["COMPLETED_PROB"] = "1.0" # make it complete immediately + pipeline = dlt.pipeline("test_unique_values_unordered_rows", destination="dummy") + + def _assert_state(r_: DltResource, day: int, info: NormalizeInfo) -> None: + uniq_hashes = r_.state["incremental"]["updated_at"]["unique_hashes"] + row_count = info.row_counts.get("random_ascending_chunks", 0) + if primary_key == "updated_at": + # we keep only newest version of the record + assert len(uniq_hashes) == 1 + if day == 1: + # all records loaded + assert row_count == 11 + elif day == 4: + # new biggest item loaded + assert row_count == 1 + else: + # all deduplicated + assert row_count == 0 + elif primary_key is None: + # we deduplicate over full content + if day == 4: + assert len(uniq_hashes) == 1 + # both the 100 or 1000 are in if non deterministic content + assert row_count == (2 if not deterministic else 1) + else: + # each day adds new hash if content non deterministic + assert len(uniq_hashes) == (day if not deterministic else 1) + if day == 1: + assert row_count == 11 + else: + assert row_count == (1 if not 
deterministic else 0) + elif primary_key == []: + # no deduplication + assert len(uniq_hashes) == 0 + if day == 4: + assert row_count == 2 + else: + if day == 1: + assert row_count == 11 + else: + assert row_count == 1 + + r_ = random_ascending_chunks(1) + pipeline.run(r_) + _assert_state(r_, 1, pipeline.last_trace.last_normalize_info) + r_ = random_ascending_chunks(2) + pipeline.run(r_) + _assert_state(r_, 2, pipeline.last_trace.last_normalize_info) + r_ = random_ascending_chunks(3) + pipeline.run(r_) + _assert_state(r_, 3, pipeline.last_trace.last_normalize_info) + r_ = random_ascending_chunks(4) + pipeline.run(r_) + _assert_state(r_, 4, pipeline.last_trace.last_normalize_info) @pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS)
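
A minimal usage sketch of the behavior the docs change in PATCH 3 describes and the `deduplication_disabled` branch in PATCH 4 handles: passing an empty tuple as `primary_key` directly to `incremental` disables deduplication of rows that share the boundary cursor value, so no unique hashes are computed or stored. The `events` resource, its field names, and the `duckdb` destination below are illustrative assumptions, not part of this patch series:

```py
from typing import Any

import dlt


@dlt.resource(primary_key="id")  # "id" remains a table hint for the destination
def events(
    updated_at: dlt.sources.incremental[int] = dlt.sources.incremental(
        "updated_at",
        initial_value=0,
        primary_key=(),  # empty tuple: skip hash computation and deduplication entirely
    ),
) -> Any:
    # rows with updated_at equal to the last seen value may be loaded again on the
    # next run, because no unique hashes are kept in the incremental state
    yield [{"id": 1, "updated_at": 1}, {"id": 2, "updated_at": 2}]


pipeline = dlt.pipeline("dedup_disabled_example", destination="duckdb")
pipeline.run(events())
```

Leaving `primary_key` at its default (`None`) keeps deduplication on and falls back to hashing the whole row when the resource defines no primary key; only an explicit empty list or tuple switches it off.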