simplifies and fixes incremental / fixes #971 #1062

Merged 4 commits on Mar 8, 2024
23 changes: 19 additions & 4 deletions dlt/extract/incremental/__init__.py
Original file line number Diff line number Diff line change
@@ -163,11 +163,12 @@ def _make_transforms(self) -> None:
self._transformers[dt] = kls(
self.resource_name,
self.cursor_path,
self.initial_value,
self.start_value,
self.end_value,
self._cached_state,
self.last_value_func,
self._primary_key,
set(self._cached_state["unique_hashes"]),
)

@classmethod
@@ -453,14 +454,28 @@ def __call__(self, rows: TDataItems, meta: Any = None) -> Optional[TDataItems]:
return rows

transformer = self._get_transformer(rows)

if isinstance(rows, list):
return [
rows = [
item
for item in (self._transform_item(transformer, row) for row in rows)
if item is not None
]
return self._transform_item(transformer, rows)
else:
rows = self._transform_item(transformer, rows)

# write back state
self._cached_state["last_value"] = transformer.last_value
if not transformer.deduplication_disabled:
# compute hashes for new last rows
unique_hashes = set(
transformer.compute_unique_value(row, self.primary_key)
for row in transformer.last_rows
)
# add directly computed hashes
unique_hashes.update(transformer.unique_hashes)
self._cached_state["unique_hashes"] = list(unique_hashes)

return rows


Incremental.EMPTY = Incremental[Any]("")
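A minimal, self-contained sketch of the batch-level write-back introduced above: per-row processing only updates the transformer, and `_cached_state` is persisted once per batch instead of being mutated row by row. `SketchTransformer` and the integer cursor are stand-ins invented for this example; only `last_value`, `last_rows`, `unique_hashes`, and `deduplication_disabled` mirror names from the diff.

```python
# Sketch (hypothetical stand-in class) of the write-back pattern above.
from typing import Any, Dict, List, Set


class SketchTransformer:
    """Stand-in transformer: tracks the max cursor value and the rows holding it."""

    deduplication_disabled = False

    def __init__(self, state: Dict[str, Any]) -> None:
        self.last_value: int = state["last_value"]
        self.unique_hashes: Set[str] = set(state["unique_hashes"])
        self.last_rows: List[Dict[str, Any]] = []

    def compute_unique_value(self, row: Dict[str, Any], primary_key: str) -> str:
        return f"{primary_key}-{row[primary_key]}"

    def __call__(self, row: Dict[str, Any]) -> Dict[str, Any]:
        if row["id"] > self.last_value:
            self.last_value = row["id"]
            self.last_rows = [row]      # only rows at the new max matter
            self.unique_hashes = set()  # hashes for the old max no longer apply
        elif row["id"] == self.last_value:
            self.last_rows.append(row)
        return row


cached_state: Dict[str, Any] = {"last_value": 3, "unique_hashes": ["id-3"]}
transformer = SketchTransformer(cached_state)
for row in [{"id": 3}, {"id": 4}, {"id": 4}]:
    transformer(row)

# write back state once per batch, mirroring the diff above
cached_state["last_value"] = transformer.last_value
if not transformer.deduplication_disabled:
    hashes = {transformer.compute_unique_value(r, "id") for r in transformer.last_rows}
    hashes.update(transformer.unique_hashes)
    cached_state["unique_hashes"] = list(hashes)

print(cached_state)  # {'last_value': 4, 'unique_hashes': ['id-4']}
```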
221 changes: 115 additions & 106 deletions dlt/extract/incremental/transform.py
@@ -1,17 +1,17 @@
from datetime import datetime, date # noqa: I251
from typing import Any, Optional, Tuple, List
from typing import Any, Optional, Set, Tuple, List

from dlt.common.exceptions import MissingDependencyException
from dlt.common.utils import digest128
from dlt.common.json import json
from dlt.common import pendulum
from dlt.common.typing import TDataItem, TDataItems
from dlt.common.jsonpath import TJsonPath, find_values, JSONPathFields, compile_path
from dlt.common.typing import TDataItem
from dlt.common.jsonpath import find_values, JSONPathFields, compile_path
from dlt.extract.incremental.exceptions import (
IncrementalCursorPathMissing,
IncrementalPrimaryKeyMissing,
)
from dlt.extract.incremental.typing import IncrementalColumnState, TCursorValue, LastValueFunc
from dlt.extract.incremental.typing import TCursorValue, LastValueFunc
from dlt.extract.utils import resolve_column_value
from dlt.extract.items import TTableHintTemplate
from dlt.common.schema.typing import TColumnNames
@@ -34,19 +34,24 @@ def __init__(
self,
resource_name: str,
cursor_path: str,
initial_value: Optional[TCursorValue],
start_value: Optional[TCursorValue],
end_value: Optional[TCursorValue],
incremental_state: IncrementalColumnState,
last_value_func: LastValueFunc[TCursorValue],
primary_key: Optional[TTableHintTemplate[TColumnNames]],
unique_hashes: Set[str],
) -> None:
self.resource_name = resource_name
self.cursor_path = cursor_path
self.initial_value = initial_value
self.start_value = start_value
self.last_value = start_value
self.end_value = end_value
self.incremental_state = incremental_state
self.last_rows: List[TDataItem] = []
self.last_value_func = last_value_func
self.primary_key = primary_key
self.unique_hashes = unique_hashes
self.start_unique_hashes = set(unique_hashes)

# compile jsonpath
self._compiled_cursor_path = compile_path(cursor_path)
@@ -59,29 +64,38 @@ def __init__(
self.cursor_path = self._compiled_cursor_path.fields[0]
self._compiled_cursor_path = None

def __call__(
self,
row: TDataItem,
) -> Tuple[bool, bool, bool]: ...


class JsonIncremental(IncrementalTransform):
def unique_value(
def compute_unique_value(
self,
row: TDataItem,
primary_key: Optional[TTableHintTemplate[TColumnNames]],
resource_name: str,
) -> str:
try:
assert not self.deduplication_disabled, (
f"{self.resource_name}: Attempt to compute unique values when deduplication is"
" disabled"
)

if primary_key:
return digest128(json.dumps(resolve_column_value(primary_key, row), sort_keys=True))
elif primary_key is None:
return digest128(json.dumps(row, sort_keys=True))
else:
return None
except KeyError as k_err:
raise IncrementalPrimaryKeyMissing(resource_name, k_err.args[0], row)
raise IncrementalPrimaryKeyMissing(self.resource_name, k_err.args[0], row)
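
A hedged sketch of the hashing above, with `hashlib` standing in for dlt's `digest128` and a plain dict lookup standing in for `resolve_column_value`; the idea is a stable hash of either the key columns or, when `primary_key` is `None`, the whole row.

```python
# Sketch only: hashlib stands in for digest128, dict access for resolve_column_value.
import hashlib
import json
from typing import Any, Dict, Optional, Sequence, Union


def compute_unique_value(
    row: Dict[str, Any], primary_key: Optional[Union[str, Sequence[str]]]
) -> str:
    if primary_key:
        keys = [primary_key] if isinstance(primary_key, str) else list(primary_key)
        payload: Any = {k: row[k] for k in keys}  # KeyError -> IncrementalPrimaryKeyMissing
    else:
        payload = row  # primary_key is None: hash the full row
    return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()[:32]


row = {"id": 1, "updated_at": "2024-03-08", "name": "a"}
# same key values -> same hash, regardless of other columns
assert compute_unique_value(row, "id") == compute_unique_value({"id": 1, "name": "b"}, "id")
# whole-row hashing is sensitive to every column
assert compute_unique_value(row, None) != compute_unique_value({**row, "name": "b"}, None)
```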

def __call__(
self,
row: TDataItem,
) -> Tuple[bool, bool, bool]: ...

@property
def deduplication_disabled(self) -> bool:
"""Skip deduplication when length of the key is 0"""
return isinstance(self.primary_key, (list, tuple)) and len(self.primary_key) == 0
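
The property above encodes the convention the rest of the diff relies on: an empty tuple or list as `primary_key` opts the resource out of deduplication, `None` means "hash the whole row", and a named key hashes the configured column(s). A small standalone illustration:

```python
# The three primary_key cases distinguished above.
def deduplication_disabled(primary_key) -> bool:
    return isinstance(primary_key, (list, tuple)) and len(primary_key) == 0


assert deduplication_disabled(()) is True             # explicit opt-out
assert deduplication_disabled([]) is True             # explicit opt-out
assert deduplication_disabled(None) is False          # dedup via whole-row hash
assert deduplication_disabled("updated_at") is False  # dedup via named column(s)
```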


class JsonIncremental(IncrementalTransform):
def find_cursor_value(self, row: TDataItem) -> Any:
"""Finds value in row at cursor defined by self.cursor_path.

@@ -113,7 +127,8 @@ def __call__(
return row, False, False

row_value = self.find_cursor_value(row)
last_value = self.incremental_state["last_value"]
last_value = self.last_value
last_value_func = self.last_value_func

# For datetime cursor, ensure the value is a timezone aware datetime.
# The object saved in state will always be a tz aware pendulum datetime so this ensures values are comparable
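The coercion these comments describe is collapsed out of this hunk; a hedged sketch of the idea, assuming naive datetimes should be treated as UTC:

```python
# Sketch: make a naive row value comparable with the tz-aware state value.
from datetime import datetime

import pendulum

row_value = datetime(2024, 3, 8, 12, 0)  # naive datetime coming from a row
if isinstance(row_value, datetime) and row_value.tzinfo is None:
    row_value = pendulum.instance(row_value)  # attaches UTC by default
```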
@@ -128,63 +143,71 @@
# Check whether end_value has been reached
# Filter end value ranges exclusively, so in case of "max" function we remove values >= end_value
if self.end_value is not None and (
self.last_value_func((row_value, self.end_value)) != self.end_value
or self.last_value_func((row_value,)) == self.end_value
last_value_func((row_value, self.end_value)) != self.end_value
or last_value_func((row_value,)) == self.end_value
):
return None, False, True

check_values = (row_value,) + ((last_value,) if last_value is not None else ())
new_value = self.last_value_func(check_values)
new_value = last_value_func(check_values)
# new_value is "less" or equal to last_value (the actual max)
if last_value == new_value:
processed_row_value = self.last_value_func((row_value,))
# we store row id for all records with the current "last_value" in state and use it to deduplicate

if processed_row_value == last_value:
unique_value = self.unique_value(row, self.primary_key, self.resource_name)
# if unique value exists then use it to deduplicate
if unique_value:
if unique_value in self.incremental_state["unique_hashes"]:
return None, False, False
# add new hash only if the record row id is same as current last value
self.incremental_state["unique_hashes"].append(unique_value)
return row, False, False
# skip the record that is not a last_value or new_value: that record was already processed
# use func to compute row_value into last_value compatible
processed_row_value = last_value_func((row_value,))
# skip the record that is not a start_value or new_value: that record was already processed
check_values = (row_value,) + (
(self.start_value,) if self.start_value is not None else ()
)
new_value = self.last_value_func(check_values)
new_value = last_value_func(check_values)
# Include rows == start_value but exclude "lower"
if new_value == self.start_value and processed_row_value != self.start_value:
return None, True, False
else:
return row, False, False
# new_value is "less" or equal to start_value (the initial max)
if new_value == self.start_value:
# if equal there's still a chance that item gets in
if processed_row_value == self.start_value:
if not self.deduplication_disabled:
unique_value = self.compute_unique_value(row, self.primary_key)
# if unique value exists then use it to deduplicate
if unique_value in self.start_unique_hashes:
return None, True, False
else:
# smaller than start value: gets out
return None, True, False

# we store row id for all records with the current "last_value" in state and use it to deduplicate
if processed_row_value == last_value:
# add new hash only if the record row id is same as current last value
self.last_rows.append(row)
else:
self.incremental_state["last_value"] = new_value
unique_value = self.unique_value(row, self.primary_key, self.resource_name)
if unique_value:
self.incremental_state["unique_hashes"] = [unique_value]
self.last_value = new_value
# store rows with "max" values to compute hashes after processing full batch
self.last_rows = [row]
self.unique_hashes = set()

return row, False, False
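
Why only boundary rows need hashing: with `last_value_func=max`, a row strictly below `start_value` was necessarily seen in a previous run, a row strictly above is necessarily new, and only rows exactly at the boundary are ambiguous, so the persisted `start_unique_hashes` are consulted just for those. A hedged sketch of that decision ladder, with an integer cursor and a hypothetical `classify` helper:

```python
# Decision ladder for last_value_func=max (sketch; returns keep_row, start_out_of_range).
from typing import Set, Tuple


def classify(
    row_value: int, start_value: int, row_hash: str, start_hashes: Set[str]
) -> Tuple[bool, bool]:
    if row_value < start_value:
        return False, True  # older than the window: seen before, drop
    if row_value == start_value:
        # ambiguous: same cursor as the previous run's boundary, use hashes
        seen = row_hash in start_hashes
        return not seen, seen
    return True, False  # strictly newer: always keep


assert classify(5, 10, "h5", {"h10"}) == (False, True)     # below the boundary
assert classify(10, 10, "h10", {"h10"}) == (False, True)   # duplicate at the boundary
assert classify(10, 10, "h10b", {"h10"}) == (True, False)  # new row at the boundary
assert classify(11, 10, "h11", {"h10"}) == (True, False)   # above the boundary
```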


class ArrowIncremental(IncrementalTransform):
_dlt_index = "_dlt_index"

def unique_values(
self, item: "TAnyArrowItem", unique_columns: List[str], resource_name: str
def compute_unique_values(self, item: "TAnyArrowItem", unique_columns: List[str]) -> List[str]:
if not unique_columns:
return []
rows = item.select(unique_columns).to_pylist()
return [self.compute_unique_value(row, self.primary_key) for row in rows]

def compute_unique_values_with_index(
self, item: "TAnyArrowItem", unique_columns: List[str]
) -> List[Tuple[int, str]]:
if not unique_columns:
return []
item = item
indices = item[self._dlt_index].to_pylist()
rows = item.select(unique_columns).to_pylist()
return [
(index, digest128(json.dumps(row, sort_keys=True))) for index, row in zip(indices, rows)
(index, self.compute_unique_value(row, self.primary_key))
for index, row in zip(indices, rows)
]

def _deduplicate(
self, tbl: "pa.Table", unique_columns: Optional[List[str]], aggregate: str, cursor_path: str
) -> "pa.Table":
def _add_unique_index(self, tbl: "pa.Table") -> "pa.Table":
"""Creates unique index if necessary."""
# create unique index if necessary
if self._dlt_index not in tbl.schema.names:
@@ -215,24 +238,18 @@ def __call__(
self._dlt_index = primary_key
elif primary_key is None:
unique_columns = tbl.schema.names
else: # deduplicating is disabled
unique_columns = None

start_out_of_range = end_out_of_range = False
if not tbl: # row is None or empty arrow table
return tbl, start_out_of_range, end_out_of_range

last_value = self.incremental_state["last_value"]

if self.last_value_func is max:
compute = pa.compute.max
aggregate = "max"
end_compare = pa.compute.less
last_value_compare = pa.compute.greater_equal
new_value_compare = pa.compute.greater
elif self.last_value_func is min:
compute = pa.compute.min
aggregate = "min"
end_compare = pa.compute.greater
last_value_compare = pa.compute.less_equal
new_value_compare = pa.compute.less
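For orientation, the selection above pairs each aggregate with its arrow comparison kernels: for `max`, rows are kept when `cursor >= start_value` and the last value advances when `row_value > last_value`; for `min` every comparison flips. A compact restatement:

```python
# The min/max -> pyarrow.compute kernel mapping chosen above (for reference).
import pyarrow.compute as pc

COMPARES = {
    max: dict(end=pc.less, keep=pc.greater_equal, advance=pc.greater),
    min: dict(end=pc.greater, keep=pc.less_equal, advance=pc.less),
}
```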
@@ -267,64 +284,56 @@ def __call__(
# NOTE: pyarrow bool *always* evaluates to python True. `as_py()` is necessary
end_out_of_range = not end_compare(row_value_scalar, end_value_scalar).as_py()

if last_value is not None:
if self.start_value is not None:
# Remove rows lower than the last start value
keep_filter = last_value_compare(
tbl[cursor_path], to_arrow_scalar(self.start_value, cursor_data_type)
if self.start_value is not None:
start_value_scalar = to_arrow_scalar(self.start_value, cursor_data_type)
# Remove rows lower or equal than the last start value
keep_filter = last_value_compare(tbl[cursor_path], start_value_scalar)
start_out_of_range = bool(pa.compute.any(pa.compute.invert(keep_filter)).as_py())
tbl = tbl.filter(keep_filter)
if not self.deduplication_disabled:
# Deduplicate after filtering old values
tbl = self._add_unique_index(tbl)
# Remove already processed rows where the cursor is equal to the start value
eq_rows = tbl.filter(pa.compute.equal(tbl[cursor_path], start_value_scalar))
# compute index, unique hash mapping
unique_values_index = self.compute_unique_values_with_index(eq_rows, unique_columns)
unique_values_index = [
(i, uq_val)
for i, uq_val in unique_values_index
if uq_val in self.start_unique_hashes
]
# find rows with unique ids that were stored from previous run
remove_idx = pa.array(i for i, _ in unique_values_index)
# Filter the table
tbl = tbl.filter(
pa.compute.invert(pa.compute.is_in(tbl[self._dlt_index], remove_idx))
)
start_out_of_range = bool(pa.compute.any(pa.compute.invert(keep_filter)).as_py())
tbl = tbl.filter(keep_filter)

# Deduplicate after filtering old values
last_value_scalar = to_arrow_scalar(last_value, cursor_data_type)
tbl = self._deduplicate(tbl, unique_columns, aggregate, cursor_path)
# Remove already processed rows where the cursor is equal to the last value
eq_rows = tbl.filter(pa.compute.equal(tbl[cursor_path], last_value_scalar))
# compute index, unique hash mapping
unique_values = self.unique_values(eq_rows, unique_columns, self.resource_name)
unique_values = [
(i, uq_val)
for i, uq_val in unique_values
if uq_val in self.incremental_state["unique_hashes"]
]
remove_idx = pa.array(i for i, _ in unique_values)
# Filter the table
tbl = tbl.filter(pa.compute.invert(pa.compute.is_in(tbl[self._dlt_index], remove_idx)))

if (
new_value_compare(row_value_scalar, last_value_scalar).as_py()
and row_value != last_value
): # Last value has changed
self.incremental_state["last_value"] = row_value

if (
self.last_value is None
or new_value_compare(
row_value_scalar, to_arrow_scalar(self.last_value, cursor_data_type)
).as_py()
): # Last value has changed
self.last_value = row_value
if not self.deduplication_disabled:
# Compute unique hashes for all rows equal to row value
self.incremental_state["unique_hashes"] = [
uq_val
for _, uq_val in self.unique_values(
self.unique_hashes = set(
self.compute_unique_values(
tbl.filter(pa.compute.equal(tbl[cursor_path], row_value_scalar)),
unique_columns,
self.resource_name,
)
]
else:
# last value is unchanged, add the hashes
self.incremental_state["unique_hashes"] = list(
set(
self.incremental_state["unique_hashes"]
+ [uq_val for _, uq_val in unique_values]
)
)
else:
tbl = self._deduplicate(tbl, unique_columns, aggregate, cursor_path)
self.incremental_state["last_value"] = row_value
self.incremental_state["unique_hashes"] = [
uq_val
for _, uq_val in self.unique_values(
tbl.filter(pa.compute.equal(tbl[cursor_path], row_value_scalar)),
unique_columns,
self.resource_name,
elif self.last_value == row_value and not self.deduplication_disabled:
# last value is unchanged, add the hashes
self.unique_hashes.update(
set(
self.compute_unique_values(
tbl.filter(pa.compute.equal(tbl[cursor_path], row_value_scalar)),
unique_columns,
)
)
]
)

if len(tbl) == 0:
return None, start_out_of_range, end_out_of_range
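A standalone sketch of the arrow-side filtering above, simplified to an integer cursor with `max`, a precomputed `row_hash` column instead of `digest128`, and an explicit `_dlt_index`; all column names here are assumptions for the example:

```python
# Sketch: boundary filter + hash-based dedup of rows equal to start_value.
import pyarrow as pa
import pyarrow.compute as pc

start_value = 10
start_unique_hashes = {"h10-a"}  # hashes persisted for rows at start_value

tbl = pa.table(
    {
        "updated_at": [9, 10, 10, 11],
        "row_hash": ["h9", "h10-a", "h10-b", "h11"],
        "_dlt_index": [0, 1, 2, 3],
    }
)

# 1. keep rows with cursor >= start_value (last_value_func=max)
keep_filter = pc.greater_equal(tbl["updated_at"], pa.scalar(start_value))
start_out_of_range = bool(pc.any(pc.invert(keep_filter)).as_py())
tbl = tbl.filter(keep_filter)

# 2. among rows equal to start_value, drop those whose hash was already seen
eq_rows = tbl.filter(pc.equal(tbl["updated_at"], pa.scalar(start_value)))
remove_idx = pa.array(
    [
        i
        for i, h in zip(eq_rows["_dlt_index"].to_pylist(), eq_rows["row_hash"].to_pylist())
        if h in start_unique_hashes
    ]
)
tbl = tbl.filter(pc.invert(pc.is_in(tbl["_dlt_index"], value_set=remove_idx)))

print(tbl["row_hash"].to_pylist())  # ['h10-b', 'h11']
print(start_out_of_range)           # True: the row with cursor 9 was dropped
```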
@@ -140,4 +140,4 @@ def get_pages(

# check that stuff was loaded
row_counts = pipeline.last_trace.last_normalize_info.row_counts
assert row_counts["ticket_events"] == 17
assert row_counts["ticket_events"] == 17