From c0e4795b2c6854238df7741f03e067fb789ce442 Mon Sep 17 00:00:00 2001
From: Marcel Coetzee <34739235+Pipboyguy@users.noreply.github.com>
Date: Wed, 6 Nov 2024 06:39:50 +0200
Subject: [PATCH] LanceDB - Remove Orphaned Chunks (#1620)

* Add tests for LanceDB chunking and merging functionality
* Add TSplitter type alias for LanceDB document splitting function
* Refine typing for chunks
* Add type definitions for chunk splitter function and related types
* Remove unused ChunkInputT, ChunkOutputT, and TSplitter type definitions
* Implement efficient update strategy for chunked documents in LanceDB
* Implement efficient update strategy for chunked documents in LanceDB
* Refactor LanceDB client and tests for improved readability and type safety
* Linting
* Add document_id parameter to lancedb_adapter and update merge logic
* Remove resolved comments
* Implement efficient orphan removal for chunked documents in LanceDB
* Implement efficient update strategy for chunked documents in LanceDB
* Add test for removing orphaned records in LanceDB
* Update LanceDB orphaned records removal test for chunked documents
* Set test pipeline as dev mode
* Fix write disposition check in LanceDBRemoveOrphansJob execute method
* Add FollowupJob trait to LoadLanceDBJob
* Fix file type
* Fix file typing
* Add test for removing orphaned records in LanceDB root table
* Enhance LanceDB test to cover nested child removal and update scenarios
* Use doc id hint for top level tables
* Only join on join columns for orphan removal job
* Add ollama to supported embedding providers and test orphaned record removal with embeddings
* Add merge_key to document resource for efficient updates in LanceDB
* Formatting
* Set default file size to 128MB
* Only use parquet loader file formats
* Import pyarrow.parquet
* Remove recommended file size from LanceDB destination capabilities
* Update LanceDB client to use more efficient batch processing methods on loading for Load Jobs
* Refactor unique identifier handling for LanceDB tables
* Optimize UUID column generation for LanceDB tables
* Refactor LanceDBClient to use string type hints for Table
* Minor refactor
* Implement efficient schema update with Nullability support
* Optimize orphaned chunks removal for large datasets
* Projection pushdown
* Format
* Prevent primary key and document ID hint conflict in merge disposition
* Add recommended file size for LanceDB destination
* Improve comment clarity for projection push-down in LanceDB
* Update to new load interface
* Remove unnecessary LanceDBLoadJob attributes
* Change instance attributes to `run` method as variables
* Schedule follow up reference job
* Add follow up lancedb remove orphan job skeleton
* Write empty follow up file
* Write parquet
* Add support for reference file format in LanceDB destination
* Handle parent table name resolution if it doesn't exist in LanceDB remove orphan job
* Refactor specialised orphan follow up job back to reference job
* Refactor orphan removal for chunked documents
* Fix dlt system table check for name instead of object
* Implement staging methods
* Override staging client methods
* Docs
* Override staging client methods
* Delete with inserts
* Keep with batch reader
* Remove LanceDB client's staging implementation
* Insert in memory arrow table. This will be optimized
* Rename classes to the new job implementation classes
* Use namedtuple for table chain to improve readability
* Remove orphans by loading all ancestor IDs simultaneously
* Fix doc_id adapter
* Fix doc_id adapter
* Revert to previous
* Revert "Remove orphans by loading all ancestor IDs simultaneously"
  This reverts commit 06e04d9fe6fbb76095129d93f5872c3c42566ca8.
* Remove doc_id hint
* Infer merge key if not supplied from provided primary key
* Remove unused utility functions
* Remove LanceDB doc ID hints and use schema normalizer
* LanceDB writes strange code
* Minor Formatting
* Support compound primary and merge keys
* Remove old comment
* Change default vector column name to "vector" to conform with the LanceDB standard; add search tests with tantivy as search engine
* Format and fix linting
* Add custom embedding function registration test
* Spawn process in test to make sure registry can be deserialized from arrow files
* Simplify null string handling
* Change NULL string replacement with random string, doc clarification
* Update default vector column name in docs
* Set `remove_orphans` flag to False on tests that don't require it
* Implement starter arrow string placeholder function
* Add test for empty arrow string element vectorised replacement utility function
* Handle NULL values in addition to empty strings in arrow substitution method
* More efficient empty value replacement with canonical arrow usage
* Format
* Bump pyarrow version
* Use pa.nulls instead of [None]*len
* Update tests
* Invert remove orphans flag
* Implement root table orphan deletion, only integer doc_ids
* Cater for string ids as well in doc_id removal process
* Fix test with wrong primary key
* Just send list of ids as is; don't pc.compute on client end
* Extract schema matching into utils
* Add utils
* Pass all tests
* Minor format and cleanup
* Docs
* Amend replace test to use a large number of records to catch race conditions with replace disposition
* Fix replace race conditions by delegating truncation to dlt
* Update lock file
* Refactor type mapping and schema handling in LanceDB client
* Change 'complex' column type to 'json' in LanceDB client
* update lock file
* fixes generating lancedb literals
* verifies merge key early, fixes column override in adapters
* fixes linting errors

---------

Signed-off-by: Marcel Coetzee
Co-authored-by: Marcin Rudolf
---
 dlt/common/data_writers/escape.py                  |  17 +
 .../impl/databricks/databricks.py                  |   2 +-
 .../impl/databricks/sql_client.py                  |   4 +-
 .../impl/lancedb/configuration.py                  |   3 +-
 dlt/destinations/impl/lancedb/factory.py           |   8 +-
 .../impl/lancedb/lancedb_adapter.py                |  27 +-
 .../impl/lancedb/lancedb_client.py                 | 427 ++++++++++++------
 dlt/destinations/impl/lancedb/schema.py            |  32 +-
 dlt/destinations/impl/lancedb/utils.py             |  80 ++--
 .../impl/qdrant/qdrant_adapter.py                  |   3 +-
 .../impl/weaviate/weaviate_adapter.py              |   1 +
 dlt/destinations/sql_jobs.py                       |   1 -
 .../dlt-ecosystem/destinations/lancedb.md          |  49 +-
 poetry.lock                                        | 108 ++---
 pyproject.toml                                     |   2 +-
 tests/load/lancedb/test_merge.py                   | 425 +++++++++++++++++
 tests/load/lancedb/test_pipeline.py                | 197 +++++---
 tests/load/lancedb/test_utils.py                   |  46 ++
 tests/load/lancedb/utils.py                        |  13 +-
 tests/load/utils.py                                |   2 +-
 20 files changed, 1152 insertions(+), 295 deletions(-)
 create mode 100644 tests/load/lancedb/test_merge.py
 create mode 100644 tests/load/lancedb/test_utils.py

diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py
index 06c8d7a95a..393e9e8508 100644
--- a/dlt/common/data_writers/escape.py
+++ b/dlt/common/data_writers/escape.py
@@ -79,6 +79,23 @@ def escape_duckdb_literal(v: Any) -> Any:
     return str(v)
 
 
+def escape_lancedb_literal(v: Any) -> Any:
+    if isinstance(v, str):
+        # we escape extended string which behave like the redshift string
+        return _escape_extended(v, prefix="'")
+    if isinstance(v, (datetime, date, time)):
+        return f"'{v.isoformat()}'"
+    if isinstance(v, (list, dict)):
+        return _escape_extended(json.dumps(v), prefix="'")
+    # TODO: check how binaries are represented in fusion
+    if isinstance(v, bytes):
+        return f"from_base64('{base64.b64encode(v).decode('ascii')}')"
+    if v is None:
+        return "NULL"
+
+    return str(v)
+
+
 MS_SQL_ESCAPE_DICT = {
     "'": "''",
     "\n": "' + CHAR(10) + N'",
diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py
index 75bd8ffa13..718427af87 100644
--- a/dlt/destinations/impl/databricks/databricks.py
+++ b/dlt/destinations/impl/databricks/databricks.py
@@ -224,7 +224,7 @@ def __init__(
         )
         super().__init__(schema, config, sql_client)
         self.config: DatabricksClientConfiguration = config
-        self.sql_client: DatabricksSqlClient = sql_client  # type: ignore[assignment]
+        self.sql_client: DatabricksSqlClient = sql_client
         self.type_mapper = self.capabilities.get_type_mapper()
 
     def create_load_job(
diff --git a/dlt/destinations/impl/databricks/sql_client.py
b/dlt/destinations/impl/databricks/sql_client.py index 88d47410d5..8bff4e0d73 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -41,7 +41,7 @@ class DatabricksCursorImpl(DBApiCursorImpl): """Use native data frame support if available""" - native_cursor: DatabricksSqlCursor # type: ignore[assignment] + native_cursor: DatabricksSqlCursor vector_size: ClassVar[int] = 2048 # vector size is 2048 def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]: @@ -144,7 +144,7 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB # db_args = kwargs or None db_args = args or kwargs or None - with self._conn.cursor() as curr: # type: ignore[assignment] + with self._conn.cursor() as curr: curr.execute(query, db_args) yield DatabricksCursorImpl(curr) # type: ignore[abstract] diff --git a/dlt/destinations/impl/lancedb/configuration.py b/dlt/destinations/impl/lancedb/configuration.py index 329132f495..8f6a192bb0 100644 --- a/dlt/destinations/impl/lancedb/configuration.py +++ b/dlt/destinations/impl/lancedb/configuration.py @@ -59,6 +59,7 @@ class LanceDBClientOptions(BaseConfiguration): "sentence-transformers", "huggingface", "colbert", + "ollama", ] @@ -92,8 +93,6 @@ class LanceDBClientConfiguration(DestinationClientDwhConfiguration): Make sure it corresponds with the associated embedding model's dimensionality.""" vector_field_name: str = "vector" """Name of the special field to store the vector embeddings.""" - id_field_name: str = "id__" - """Name of the special field to manage deduplication.""" sentinel_table_name: str = "dltSentinelTable" """Name of the sentinel table that encapsulates datasets. Since LanceDB has no concept of schemas, this table serves as a proxy to group related dlt tables together.""" diff --git a/dlt/destinations/impl/lancedb/factory.py b/dlt/destinations/impl/lancedb/factory.py index 8ce2217007..d0d22ed3fb 100644 --- a/dlt/destinations/impl/lancedb/factory.py +++ b/dlt/destinations/impl/lancedb/factory.py @@ -26,8 +26,8 @@ class lancedb(Destination[LanceDBClientConfiguration, "LanceDBClient"]): def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "jsonl" - caps.supported_loader_file_formats = ["jsonl"] + caps.preferred_loader_file_format = "parquet" + caps.supported_loader_file_formats = ["parquet", "reference"] caps.type_mapper = LanceDBTypeMapper caps.max_identifier_length = 200 @@ -42,6 +42,10 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.timestamp_precision = 6 caps.supported_replace_strategies = ["truncate-and-insert"] + caps.recommended_file_size = 128_000_000 + + caps.supported_merge_strategies = ["upsert"] + return caps @property diff --git a/dlt/destinations/impl/lancedb/lancedb_adapter.py b/dlt/destinations/impl/lancedb/lancedb_adapter.py index 99d5ef43c6..4314dd703f 100644 --- a/dlt/destinations/impl/lancedb/lancedb_adapter.py +++ b/dlt/destinations/impl/lancedb/lancedb_adapter.py @@ -1,16 +1,20 @@ -from typing import Any +from typing import Any, Dict from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns from dlt.destinations.utils import get_resource_for_adapter from dlt.extract import DltResource +from dlt.extract.items import TTableHintTemplate VECTORIZE_HINT = "x-lancedb-embed" +NO_REMOVE_ORPHANS_HINT = "x-lancedb-remove-orphans" def lancedb_adapter( data: Any, embed: TColumnNames = None, + merge_key: TColumnNames = None, + 
no_remove_orphans: bool = False,
 ) -> DltResource:
     """Prepares data for the LanceDB destination by specifying which columns should be embedded.
@@ -20,6 +24,10 @@
             object.
         embed (TColumnNames, optional): Specify columns to generate embeddings for.
             It can be a single column name as a string, or a list of column names.
+        merge_key (TColumnNames, optional): Specify columns to merge on.
+            It can be a single column name as a string, or a list of column names.
+        no_remove_orphans (bool): Specify whether to remove orphaned records in child
+            tables with no parent records after merges to maintain referential integrity.
 
     Returns:
         DltResource: A resource with applied LanceDB-specific hints.
@@ -34,7 +42,8 @@
     """
     resource = get_resource_for_adapter(data)
 
-    column_hints: TTableSchemaColumns = {}
+    additional_table_hints: Dict[str, TTableHintTemplate[Any]] = {}
+    column_hints: TTableSchemaColumns = None
 
     if embed:
         if isinstance(embed, str):
@@ -43,6 +52,7 @@
             raise ValueError(
                 "'embed' must be a list of column names or a single column name as a string."
             )
+        column_hints = {}
 
         for column_name in embed:
             column_hints[column_name] = {
@@ -50,9 +60,16 @@
                 VECTORIZE_HINT: True,  # type: ignore[misc]
             }
 
-    if not column_hints:
-        raise ValueError("A value for 'embed' must be specified.")
+    additional_table_hints[NO_REMOVE_ORPHANS_HINT] = no_remove_orphans
+
+    if column_hints or additional_table_hints or merge_key:
+        resource.apply_hints(
+            merge_key=merge_key, columns=column_hints, additional_table_hints=additional_table_hints
+        )
     else:
-        resource.apply_hints(columns=column_hints)
+        raise ValueError(
+            "You must provide at least either the 'embed' or 'merge_key' or 'remove_orphans'"
+            " argument if using the adapter."
+ ) return resource diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 8a347989a0..1a3e1a7d34 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -1,4 +1,3 @@ -import uuid from types import TracebackType from typing import ( List, @@ -12,15 +11,17 @@ Dict, Sequence, TYPE_CHECKING, + Set, ) -from dlt.common.destination.capabilities import DataTypeMapper import lancedb # type: ignore +import lancedb.table # type: ignore import pyarrow as pa +import pyarrow.parquet as pq from lancedb import DBConnection +from lancedb.common import DATA # type: ignore from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore -from lancedb.table import Table # type: ignore from numpy import ndarray from pyarrow import Array, ChunkedArray, ArrowInvalid @@ -39,53 +40,142 @@ StorageSchemaInfo, StateInfo, LoadJob, + HasFollowupJobs, + FollowupJobRequest, ) from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TSchemaTables from dlt.common.schema.typing import ( - C_DLT_LOAD_ID, + TColumnType, TTableSchemaColumns, TWriteDisposition, + TColumnSchema, + TTableSchema, ) -from dlt.common.schema.utils import get_columns_names_with_prop -from dlt.common.storages import FileStorage -from dlt.common.typing import DictStrAny +from dlt.common.schema.utils import get_columns_names_with_prop, is_nested_table +from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName from dlt.destinations.impl.lancedb.configuration import ( LanceDBClientConfiguration, ) from dlt.destinations.impl.lancedb.exceptions import ( lancedb_error, ) -from dlt.destinations.impl.lancedb.lancedb_adapter import VECTORIZE_HINT +from dlt.destinations.impl.lancedb.lancedb_adapter import ( + VECTORIZE_HINT, + NO_REMOVE_ORPHANS_HINT, +) from dlt.destinations.impl.lancedb.schema import ( make_arrow_field_schema, make_arrow_table_schema, TArrowSchema, NULL_SCHEMA, TArrowField, + arrow_datatype_to_fusion_datatype, + TTableLineage, + TableJob, ) from dlt.destinations.impl.lancedb.utils import ( - list_merge_identifiers, - generate_uuid, set_non_standard_providers_environment_variables, + EMPTY_STRING_PLACEHOLDER, + fill_empty_source_column_values_with_placeholder, + get_canonical_vector_database_doc_id_merge_key, + create_filter_condition, ) +from dlt.destinations.job_impl import ReferenceFollowupJobRequest +from dlt.destinations.type_mapping import TypeMapperImpl if TYPE_CHECKING: NDArray = ndarray[Any, Any] else: NDArray = ndarray -EMPTY_STRING_PLACEHOLDER = "0uEoDNBpQUBwsxKbmxxB" +TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} +UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} +BATCH_PROCESS_CHUNK_SIZE = 10_000 + + +class LanceDBTypeMapper(TypeMapperImpl): + sct_to_unbound_dbt = { + "text": pa.string(), + "double": pa.float64(), + "bool": pa.bool_(), + "bigint": pa.int64(), + "binary": pa.binary(), + "date": pa.date32(), + "json": pa.string(), + } + + sct_to_dbt = {} + + dbt_to_sct = { + pa.string(): "text", + pa.float64(): "double", + pa.bool_(): "bool", + pa.int64(): "bigint", + pa.binary(): "binary", + pa.date32(): "date", + } + + def to_db_decimal_type(self, column: TColumnSchema) -> pa.Decimal128Type: + precision, scale = self.decimal_precision(column.get("precision"), column.get("scale")) + return 
pa.decimal128(precision, scale) + + def to_db_datetime_type( + self, + column: TColumnSchema, + table: TTableSchema = None, + ) -> pa.TimestampType: + column_name = column.get("name") + timezone = column.get("timezone") + precision = column.get("precision") + if timezone is not None or precision is not None: + logger.warning( + "LanceDB does not currently support column flags for timezone or precision." + f" These flags were used in column '{column_name}'." + ) + unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] + return pa.timestamp(unit, "UTC") + + def to_db_time_type(self, column: TColumnSchema, table: TTableSchema = None) -> pa.Time64Type: + unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] + return pa.time64(unit) + + def from_db_type( + self, + db_type: pa.DataType, + precision: Optional[int] = None, + scale: Optional[int] = None, + ) -> TColumnType: + if isinstance(db_type, pa.TimestampType): + return dict( + data_type="timestamp", + precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], + scale=scale, + ) + if isinstance(db_type, pa.Time64Type): + return dict( + data_type="time", + precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], + scale=scale, + ) + if isinstance(db_type, pa.Decimal128Type): + precision, scale = db_type.precision, db_type.scale + if (precision, scale) == self.capabilities.wei_precision: + return cast(TColumnType, dict(data_type="wei")) + return dict(data_type="decimal", precision=precision, scale=scale) + return super().from_db_type(cast(str, db_type), precision, scale) # type: ignore -def upload_batch( - records: List[DictStrAny], +def write_records( + records: DATA, /, *, db_client: DBConnection, table_name: str, - write_disposition: TWriteDisposition, - id_field_name: Optional[str] = None, + write_disposition: Optional[TWriteDisposition] = "append", + merge_key: Optional[str] = None, + remove_orphans: Optional[bool] = False, + filter_condition: Optional[str] = None, ) -> None: """Inserts records into a LanceDB table with automatic embedding computation. @@ -93,8 +183,11 @@ def upload_batch( records: The data to be inserted as payload. db_client: The LanceDB client connection. table_name: The name of the table to insert into. - id_field_name: The name of the ID field for update/merge operations. + merge_key: Keys for update/merge operations. write_disposition: The write disposition - one of 'skip', 'append', 'replace', 'merge'. + remove_orphans (bool): Whether to remove orphans after insertion or not (only merge disposition). + filter_condition (str): If None, then all such rows will be deleted. + Otherwise, the condition will be used as an SQL filter to limit what rows are deleted. 
Raises: ValueError: If the write disposition is unsupported, or `id_field_name` is not @@ -110,16 +203,17 @@ def upload_batch( ) from e try: - if write_disposition in ("append", "skip"): + if write_disposition in ("append", "skip", "replace"): tbl.add(records) - elif write_disposition == "replace": - tbl.add(records, mode="overwrite") elif write_disposition == "merge": - if not id_field_name: - raise ValueError("To perform a merge update, 'id_field_name' must be specified.") - tbl.merge_insert( - id_field_name - ).when_matched_update_all().when_not_matched_insert_all().execute(records) + if remove_orphans: + tbl.merge_insert(merge_key).when_not_matched_by_source_delete( + filter_condition + ).execute(records) + else: + tbl.merge_insert( + merge_key + ).when_matched_update_all().when_not_matched_insert_all().execute(records) else: raise DestinationTerminalException( f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" @@ -135,6 +229,8 @@ class LanceDBClient(JobClientBase, WithStateSync): """LanceDB destination handler.""" model_func: TextEmbeddingFunction + """The embedder callback used for each chunk.""" + dataset_name: str def __init__( self, @@ -152,6 +248,7 @@ def __init__( self.registry = EmbeddingFunctionRegistry.get_instance() self.type_mapper = self.capabilities.get_type_mapper() self.sentinel_table_name = config.sentinel_table_name + self.dataset_name = self.config.normalize_dataset_name(self.schema) embedding_model_provider = self.config.embedding_model_provider @@ -169,11 +266,6 @@ def __init__( ) self.vector_field_name = self.config.vector_field_name - self.id_field_name = self.config.id_field_name - - @property - def dataset_name(self) -> str: - return self.config.normalize_dataset_name(self.schema) @property def sentinel_table(self) -> str: @@ -187,7 +279,7 @@ def make_qualified_table_name(self, table_name: str) -> str: ) def get_table_schema(self, table_name: str) -> TArrowSchema: - schema_table: Table = self.db_client.open_table(table_name) + schema_table: "lancedb.table.Table" = self.db_client.open_table(table_name) schema_table.checkout_latest() schema = schema_table.schema return cast( @@ -196,13 +288,15 @@ def get_table_schema(self, table_name: str) -> TArrowSchema: ) @lancedb_error - def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table: + def create_table( + self, table_name: str, schema: TArrowSchema, mode: str = "create" + ) -> "lancedb.table.Table": """Create a LanceDB Table from the provided LanceModel or PyArrow schema. Args: schema: The table schema to create. table_name: The name of the table to create. - mode (): The mode to use when creating the table. Can be either "create" or "overwrite". + mode (str): The mode to use when creating the table. Can be either "create" or "overwrite". By default, if the table already exists, an exception is raised. If you want to overwrite the table, use mode="overwrite". """ @@ -230,7 +324,7 @@ def query_table( Returns: A LanceDB query builder. """ - query_table: Table = self.db_client.open_table(table_name) + query_table: "lancedb.table.Table" = self.db_client.open_table(table_name) query_table.checkout_latest() return query_table.search(query=query) @@ -255,7 +349,7 @@ def drop_storage(self) -> None: Deletes all tables in the dataset and all data, as well as sentinel table associated with them. - If the dataset name was not provided, it deletes all the tables in the current schema. 
+ If the dataset name wasn't provided, it deletes all the tables in the current schema. """ for table_name in self._get_table_names(): self.db_client.drop_table(table_name) @@ -282,7 +376,22 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: def is_storage_initialized(self) -> bool: return self.table_exists(self.sentinel_table) - def _create_sentinel_table(self) -> Table: + def verify_schema( + self, only_tables: Iterable[str] = None, new_jobs: Iterable[ParsedLoadJobFileName] = None + ) -> List[PreparedTableSchema]: + loaded_tables = super().verify_schema(only_tables, new_jobs) + # verify merge keys early + for load_table in loaded_tables: + if not is_nested_table(load_table) and not load_table.get(NO_REMOVE_ORPHANS_HINT): + if merge_key := get_columns_names_with_prop(load_table, "merge_key"): + if len(merge_key) > 1: + raise DestinationTerminalException( + "You cannot specify multiple merge keys with LanceDB orphan remove" + f" enabled: {merge_key}" + ) + return loaded_tables + + def _create_sentinel_table(self) -> "lancedb.table.Table": """Create an empty table to indicate that the storage is initialized.""" return self.create_table(schema=NULL_SCHEMA, table_name=self.sentinel_table) @@ -310,7 +419,7 @@ def update_stored_schema( # TODO: return a real updated table schema (like in SQL job client) self._execute_schema_update(only_tables) else: - logger.info( + logger.debug( f"Schema with hash {self.schema.stored_version_hash} " f"inserted at {schema_info.inserted_at} found " "in storage, no upgrade required" @@ -325,7 +434,7 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] try: fq_table_name = self.make_qualified_table_name(table_name) - table: Table = self.db_client.open_table(fq_table_name) + table: "lancedb.table.Table" = self.db_client.open_table(fq_table_name) table.checkout_latest() arrow_schema: TArrowSchema = table.schema except FileNotFoundError: @@ -341,34 +450,33 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] return True, table_schema @lancedb_error - def add_table_fields( - self, table_name: str, field_schemas: List[TArrowField] - ) -> Optional[Table]: - """Add multiple fields to the LanceDB table at once. + def extend_lancedb_table_schema(self, table_name: str, field_schemas: List[pa.Field]) -> None: + """Extend LanceDB table schema with empty columns. Args: - table_name: The name of the table to create the fields on. - field_schemas: The list of fields to create. + table_name: The name of the table to create the fields on. + field_schemas: The list of PyArrow Fields to create in the target LanceDB table. """ - table: Table = self.db_client.open_table(table_name) + table: "lancedb.table.Table" = self.db_client.open_table(table_name) table.checkout_latest() - arrow_table = table.to_arrow() - - # Check if any of the new fields already exist in the table. - existing_fields = set(arrow_table.schema.names) - new_fields = [field for field in field_schemas if field.name not in existing_fields] - if not new_fields: - # All fields already present, skip. - return None + try: + # Use DataFusion SQL syntax to alter fields without loading data into client memory. + # Now, the most efficient way to modify column values is in LanceDB. 
+ new_fields = { + field.name: f"CAST(NULL AS {arrow_datatype_to_fusion_datatype(field.type)})" + for field in field_schemas + } + table.add_columns(new_fields) - null_arrays = [pa.nulls(len(arrow_table), type=field.type) for field in new_fields] + # Make new columns nullable in the Arrow schema. + # Necessary because the Datafusion SQL API doesn't set new columns as nullable by default. + for field in field_schemas: + table.alter_columns({"path": field.name, "nullable": field.nullable}) - for field, null_array in zip(new_fields, null_arrays): - arrow_table = arrow_table.append_column(field, null_array) + # TODO: Update method below doesn't work for bulk NULL assignments, raise with LanceDB developers. + # table.update(values={field.name: None}) - try: - return self.db_client.create_table(table_name, arrow_table, mode="overwrite") except OSError: # Error occurred while creating the table, skip. return None @@ -376,36 +484,31 @@ def add_table_fields( def _execute_schema_update(self, only_tables: Iterable[str]) -> None: for table_name in only_tables or self.schema.tables: exists, existing_columns = self.get_storage_table(table_name) - new_columns = self.schema.get_new_table_columns( + new_columns: List[TColumnSchema] = self.schema.get_new_table_columns( table_name, existing_columns, self.capabilities.generates_case_sensitive_identifiers(), ) - embedding_fields: List[str] = get_columns_names_with_prop( - self.schema.get_table(table_name), VECTORIZE_HINT - ) logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") - if len(new_columns) > 0: + if new_columns: if exists: field_schemas: List[TArrowField] = [ make_arrow_field_schema(column["name"], column, self.type_mapper) for column in new_columns ] fq_table_name = self.make_qualified_table_name(table_name) - self.add_table_fields(fq_table_name, field_schemas) + self.extend_lancedb_table_schema(fq_table_name, field_schemas) else: if table_name not in self.schema.dlt_table_names(): embedding_fields = get_columns_names_with_prop( self.schema.get_table(table_name=table_name), VECTORIZE_HINT ) vector_field_name = self.vector_field_name - id_field_name = self.id_field_name embedding_model_func = self.model_func embedding_model_dimensions = self.config.embedding_model_dimensions else: embedding_fields = None vector_field_name = None - id_field_name = None embedding_model_func = None embedding_model_dimensions = None @@ -417,7 +520,6 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: embedding_model_func=embedding_model_func, embedding_model_dimensions=embedding_model_dimensions, vector_field_name=vector_field_name, - id_field_name=id_field_name, ) fq_table_name = self.make_qualified_table_name(table_name) self.create_table(fq_table_name, table_schema) @@ -446,7 +548,8 @@ def update_schema_in_storage(self) -> None: write_disposition = self.schema.get_table(self.schema.version_table_name).get( "write_disposition" ) - upload_batch( + + write_records( records, db_client=self.db_client, table_name=fq_version_table_name, @@ -459,15 +562,17 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name) fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) - state_table_: Table = self.db_client.open_table(fq_state_table_name) + state_table_: "lancedb.table.Table" = self.db_client.open_table(fq_state_table_name) state_table_.checkout_latest() - loads_table_: Table = 
self.db_client.open_table(fq_loads_table_name) + loads_table_: "lancedb.table.Table" = self.db_client.open_table(fq_loads_table_name) loads_table_.checkout_latest() # normalize property names p_load_id = self.schema.naming.normalize_identifier("load_id") - p_dlt_load_id = self.schema.naming.normalize_identifier(C_DLT_LOAD_ID) + p_dlt_load_id = self.schema.naming.normalize_identifier( + self.schema.data_item_normalizer.c_dlt_load_id # type: ignore[attr-defined] + ) p_pipeline_name = self.schema.naming.normalize_identifier("pipeline_name") p_status = self.schema.naming.normalize_identifier("status") p_version = self.schema.naming.normalize_identifier("version") @@ -476,7 +581,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: p_created_at = self.schema.naming.normalize_identifier("created_at") p_version_hash = self.schema.naming.normalize_identifier("version_hash") - # Read the tables into memory as Arrow tables, with pushdown predicates, so we pull as less + # Read the tables into memory as Arrow tables, with pushdown predicates, so we pull as little # data into memory as possible. state_table = ( state_table_.search() @@ -508,7 +613,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) - version_table: Table = self.db_client.open_table(fq_version_table_name) + version_table: "lancedb.table.Table" = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() p_version_hash = self.schema.naming.normalize_identifier("version_hash") p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") @@ -524,8 +629,6 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI ) ).to_list() - # LanceDB's ORDER BY clause doesn't seem to work. - # See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341 most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], @@ -543,7 +646,7 @@ def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaIn """Retrieves newest schema from destination storage.""" fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) - version_table: Table = self.db_client.open_table(fq_version_table_name) + version_table: "lancedb.table.Table" = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() p_version_hash = self.schema.naming.normalize_identifier("version_hash") p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") @@ -558,8 +661,6 @@ def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaIn query = query.where(f'`{p_schema_name}` = "{schema_name}"', prefilter=True) schemas = query.to_list() - # LanceDB's ORDER BY clause doesn't seem to work. 
- # See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341 most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], @@ -591,16 +692,14 @@ def complete_load(self, load_id: str) -> None: self.schema.naming.normalize_identifier("schema_name"): self.schema.name, self.schema.naming.normalize_identifier("status"): 0, self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), - self.schema.naming.normalize_identifier( - "schema_version_hash" - ): None, # Payload schema must match the target schema. + self.schema.naming.normalize_identifier("schema_version_hash"): None, } ] fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) write_disposition = self.schema.get_table(self.schema.loads_table_name).get( "write_disposition" ) - upload_batch( + write_records( records, db_client=self.db_client, table_name=fq_loads_table_name, @@ -610,80 +709,152 @@ def complete_load(self, load_id: str) -> None: def create_load_job( self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: - return LanceDBLoadJob( - file_path=file_path, - type_mapper=self.type_mapper, - model_func=self.model_func, - fq_table_name=self.make_qualified_table_name(table["name"]), + if ReferenceFollowupJobRequest.is_reference_job(file_path): + return LanceDBRemoveOrphansJob(file_path) + else: + return LanceDBLoadJob(file_path, table) + + def create_table_chain_completed_followup_jobs( + self, + table_chain: Sequence[TTableSchema], + completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, + ) -> List[FollowupJobRequest]: + jobs = super().create_table_chain_completed_followup_jobs( + table_chain, completed_table_chain_jobs # type: ignore[arg-type] ) + # Orphan removal is only supported for upsert strategy because we need a deterministic key hash. 
+ first_table_in_chain = table_chain[0] + if first_table_in_chain.get( + "write_disposition" + ) == "merge" and not first_table_in_chain.get(NO_REMOVE_ORPHANS_HINT): + all_job_paths_ordered = [ + job.file_path + for table in table_chain + for job in completed_table_chain_jobs + if job.job_file_info.table_name == table.get("name") + ] + root_table_file_name = FileStorage.get_file_name_from_file_path( + all_job_paths_ordered[0] + ) + jobs.append(ReferenceFollowupJobRequest(root_table_file_name, all_job_paths_ordered)) + return jobs def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() -class LanceDBLoadJob(RunnableLoadJob): +class LanceDBLoadJob(RunnableLoadJob, HasFollowupJobs): arrow_schema: TArrowSchema def __init__( self, file_path: str, - type_mapper: DataTypeMapper, - model_func: TextEmbeddingFunction, - fq_table_name: str, + table_schema: TTableSchema, ) -> None: super().__init__(file_path) - self._type_mapper = type_mapper - self._fq_table_name: str = fq_table_name - self._model_func = model_func self._job_client: "LanceDBClient" = None + self._table_schema: TTableSchema = table_schema def run(self) -> None: - self._db_client: DBConnection = self._job_client.db_client - self._embedding_model_func: TextEmbeddingFunction = self._model_func - self._embedding_model_dimensions: int = self._job_client.config.embedding_model_dimensions - self._id_field_name: str = self._job_client.config.id_field_name - - unique_identifiers: Sequence[str] = list_merge_identifiers(self._load_table) + db_client: DBConnection = self._job_client.db_client + fq_table_name: str = self._job_client.make_qualified_table_name(self._table_schema["name"]) write_disposition: TWriteDisposition = cast( TWriteDisposition, self._load_table.get("write_disposition", "append") ) - with FileStorage.open_zipsafe_ro(self._file_path) as f: - records: List[DictStrAny] = [json.loads(line) for line in f] + with FileStorage.open_zipsafe_ro(self._file_path, mode="rb") as f: + arrow_table: pa.Table = pq.read_table(f) # Replace empty strings with placeholder string if OpenAI is used. # https://github.com/lancedb/lancedb/issues/1577#issuecomment-2318104218. if (self._job_client.config.embedding_model_provider == "openai") and ( source_columns := get_columns_names_with_prop(self._load_table, VECTORIZE_HINT) ): - records = [ - { - k: EMPTY_STRING_PLACEHOLDER if k in source_columns and v in ("", None) else v - for k, v in record.items() - } - for record in records - ] - - if self._load_table not in self._schema.dlt_tables(): - for record in records: - # Add reserved ID fields. - uuid_id = ( - generate_uuid(record, unique_identifiers, self._fq_table_name) - if unique_identifiers - else str(uuid.uuid4()) - ) - record.update({self._id_field_name: uuid_id}) + arrow_table = fill_empty_source_column_values_with_placeholder( + arrow_table, source_columns, EMPTY_STRING_PLACEHOLDER + ) - # LanceDB expects all fields in the target arrow table to be present in the data payload. - # We add and set these missing fields, that are fields not present in the target schema, to NULL. - missing_fields = set(self._load_table["columns"]) - set(record) - for field in missing_fields: - record[field] = None + # We need upsert merge's deterministic _dlt_id to perform orphan removal. + # Hence, we require at least a primary key on the root table if the merge disposition is chosen. + if ( + (self._load_table not in self._schema.dlt_table_names()) + and not is_nested_table(self._load_table) # Is root table. 
+ and (write_disposition == "merge") + and (not get_columns_names_with_prop(self._load_table, "primary_key")) + ): + raise DestinationTerminalException( + "LanceDB's write disposition requires at least one explicit primary key." + ) - upload_batch( - records, - db_client=self._db_client, - table_name=self._fq_table_name, + dlt_id = self._schema.naming.normalize_identifier( + self._schema.data_item_normalizer.c_dlt_id # type: ignore[attr-defined] + ) + write_records( + arrow_table, + db_client=db_client, + table_name=fq_table_name, write_disposition=write_disposition, - id_field_name=self._id_field_name, + merge_key=dlt_id, ) + + +class LanceDBRemoveOrphansJob(RunnableLoadJob): + orphaned_ids: Set[str] + + def __init__( + self, + file_path: str, + ) -> None: + super().__init__(file_path) + self._job_client: "LanceDBClient" = None + self.references = ReferenceFollowupJobRequest.resolve_references(file_path) + + def run(self) -> None: + dlt_load_id = self._schema.data_item_normalizer.c_dlt_load_id # type: ignore[attr-defined] + dlt_id = self._schema.data_item_normalizer.c_dlt_id # type: ignore[attr-defined] + dlt_root_id = self._schema.data_item_normalizer.c_dlt_root_id # type: ignore[attr-defined] + + db_client: DBConnection = self._job_client.db_client + table_lineage: TTableLineage = [ + TableJob( + table_schema=self._schema.get_table( + ParsedLoadJobFileName.parse(file_path_).table_name + ), + table_name=ParsedLoadJobFileName.parse(file_path_).table_name, + file_path=file_path_, + ) + for file_path_ in self.references + ] + + for job in table_lineage: + target_is_root_table = not is_nested_table(job.table_schema) + fq_table_name = self._job_client.make_qualified_table_name(job.table_name) + file_path = job.file_path + with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: + payload_arrow_table: pa.Table = pq.read_table(f) + + if target_is_root_table: + canonical_doc_id_field = get_canonical_vector_database_doc_id_merge_key( + job.table_schema + ) + filter_condition = create_filter_condition( + canonical_doc_id_field, payload_arrow_table[canonical_doc_id_field] + ) + merge_key = dlt_load_id + + else: + filter_condition = create_filter_condition( + dlt_root_id, + payload_arrow_table[dlt_root_id], + ) + merge_key = dlt_id + + write_records( + payload_arrow_table, + db_client=db_client, + table_name=fq_table_name, + write_disposition="merge", + merge_key=merge_key, + remove_orphans=True, + filter_condition=filter_condition, + ) diff --git a/dlt/destinations/impl/lancedb/schema.py b/dlt/destinations/impl/lancedb/schema.py index 27c6fb33a1..25dfbc840a 100644 --- a/dlt/destinations/impl/lancedb/schema.py +++ b/dlt/destinations/impl/lancedb/schema.py @@ -1,6 +1,5 @@ """Utilities for creating arrow schemas from table schemas.""" - -from dlt.common.json import json +from collections import namedtuple from typing import ( List, cast, @@ -11,17 +10,19 @@ from lancedb.embeddings import TextEmbeddingFunction # type: ignore from typing_extensions import TypeAlias +from dlt.common.destination.capabilities import DataTypeMapper +from dlt.common.json import json from dlt.common.schema import Schema, TColumnSchema from dlt.common.typing import DictStrAny -from dlt.common.destination.capabilities import DataTypeMapper - TArrowSchema: TypeAlias = pa.Schema TArrowDataType: TypeAlias = pa.DataType TArrowField: TypeAlias = pa.Field NULL_SCHEMA: TArrowSchema = pa.schema([]) """Empty pyarrow Schema with no fields.""" +TableJob = namedtuple("TableJob", ["table_schema", "table_name", "file_path"]) 
+TTableLineage: TypeAlias = List[TableJob] def arrow_schema_to_dict(schema: TArrowSchema) -> DictStrAny: @@ -42,7 +43,6 @@ def make_arrow_table_schema( table_name: str, schema: Schema, type_mapper: DataTypeMapper, - id_field_name: Optional[str] = None, vector_field_name: Optional[str] = None, embedding_fields: Optional[List[str]] = None, embedding_model_func: Optional[TextEmbeddingFunction] = None, @@ -51,9 +51,6 @@ def make_arrow_table_schema( """Creates a PyArrow schema from a dlt schema.""" arrow_schema: List[TArrowField] = [] - if id_field_name: - arrow_schema.append(pa.field(id_field_name, pa.string())) - if embedding_fields: # User's provided dimension config, if provided, takes precedence. vec_size = embedding_model_dimensions or embedding_model_func.ndims() @@ -83,3 +80,22 @@ def make_arrow_table_schema( metadata["embedding_functions"] = json.dumps(embedding_functions).encode("utf-8") return pa.schema(arrow_schema, metadata=metadata) + + +def arrow_datatype_to_fusion_datatype(arrow_type: TArrowSchema) -> str: + type_map = { + pa.bool_(): "BOOLEAN", + pa.int64(): "BIGINT", + pa.float64(): "DOUBLE", + pa.utf8(): "STRING", + pa.binary(): "BYTEA", + pa.date32(): "DATE", + } + + if isinstance(arrow_type, pa.Decimal128Type): + return f"DECIMAL({arrow_type.precision}, {arrow_type.scale})" + + if isinstance(arrow_type, pa.TimestampType): + return "TIMESTAMP" + + return type_map.get(arrow_type, "UNKNOWN") diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index aeacd4d34b..56991b090f 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -1,13 +1,16 @@ import os -import uuid -from typing import Sequence, Union, Dict +from typing import Union, Dict, List +import pyarrow as pa + +from dlt.common import logger +from dlt.common.data_writers.escape import escape_lancedb_literal +from dlt.common.destination.exceptions import DestinationTerminalException from dlt.common.schema import TTableSchema -from dlt.common.schema.utils import get_columns_names_with_prop -from dlt.common.typing import DictStrAny +from dlt.common.schema.utils import get_columns_names_with_prop, get_first_column_name_with_prop from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider - +EMPTY_STRING_PLACEHOLDER = "0uEoDNBpQUBwsxKbmxxB" PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = { "cohere": "COHERE_API_KEY", "gemini-text": "GOOGLE_API_KEY", @@ -16,40 +19,55 @@ } -def generate_uuid(data: DictStrAny, unique_identifiers: Sequence[str], table_name: str) -> str: - """Generates deterministic UUID - used for deduplication. +def set_non_standard_providers_environment_variables( + embedding_model_provider: TEmbeddingProvider, api_key: Union[str, None] +) -> None: + if embedding_model_provider in PROVIDER_ENVIRONMENT_VARIABLES_MAP: + os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or "" - Args: - data (Dict[str, Any]): Arbitrary data to generate UUID for. - unique_identifiers (Sequence[str]): A list of unique identifiers. - table_name (str): LanceDB table name. - Returns: - str: A string representation of the generated UUID. 
- """ - data_id = "_".join(str(data[key]) for key in unique_identifiers) - return str(uuid.uuid5(uuid.NAMESPACE_DNS, table_name + data_id)) +def get_canonical_vector_database_doc_id_merge_key( + load_table: TTableSchema, +) -> str: + if merge_key := get_first_column_name_with_prop(load_table, "merge_key"): + return merge_key + elif primary_key := get_columns_names_with_prop(load_table, "primary_key"): + # No merge key defined, warn and assume the first element of the primary key is `doc_id`. + logger.warning( + "Merge strategy selected without defined merge key - using the first element of the" + f" primary key ({primary_key}) as merge key." + ) + return primary_key[0] + else: + raise DestinationTerminalException( + "You must specify at least a primary key in order to perform orphan removal." + ) -def list_merge_identifiers(table_schema: TTableSchema) -> Sequence[str]: - """Returns a list of merge keys for a table used for either merging or deduplication. +def fill_empty_source_column_values_with_placeholder( + table: pa.Table, source_columns: List[str], placeholder: str +) -> pa.Table: + """ + Replaces empty strings and null values in the specified source columns of an Arrow table with a placeholder string. Args: - table_schema (TTableSchema): a dlt table schema. + table (pa.Table): The input Arrow table. + source_columns (List[str]): A list of column names to replace empty strings and null values in. + placeholder (str): The placeholder string to use for replacement. Returns: - Sequence[str]: A list of unique column identifiers. + pa.Table: The modified Arrow table with empty strings and null values replaced in the specified columns. """ - if table_schema.get("write_disposition") == "merge": - primary_keys = get_columns_names_with_prop(table_schema, "primary_key") - merge_keys = get_columns_names_with_prop(table_schema, "merge_key") - if join_keys := list(set(primary_keys + merge_keys)): - return join_keys - return get_columns_names_with_prop(table_schema, "unique") + for col_name in source_columns: + column = table[col_name] + filled_column = pa.compute.fill_null(column, fill_value=placeholder) + new_column = pa.compute.replace_substring_regex( + filled_column, pattern=r"^$", replacement=placeholder + ) + table = table.set_column(table.column_names.index(col_name), col_name, new_column) + return table -def set_non_standard_providers_environment_variables( - embedding_model_provider: TEmbeddingProvider, api_key: Union[str, None] -) -> None: - if embedding_model_provider in PROVIDER_ENVIRONMENT_VARIABLES_MAP: - os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or "" +def create_filter_condition(field_name: str, array: pa.Array) -> str: + array_py = array.to_pylist() + return f"{field_name} IN ({', '.join(map(escape_lancedb_literal, array_py))})" diff --git a/dlt/destinations/impl/qdrant/qdrant_adapter.py b/dlt/destinations/impl/qdrant/qdrant_adapter.py index abe301fff0..bbc2d719a8 100644 --- a/dlt/destinations/impl/qdrant/qdrant_adapter.py +++ b/dlt/destinations/impl/qdrant/qdrant_adapter.py @@ -34,7 +34,7 @@ def qdrant_adapter( """ resource = get_resource_for_adapter(data) - column_hints: TTableSchemaColumns = {} + column_hints: TTableSchemaColumns = None if embed: if isinstance(embed, str): @@ -44,6 +44,7 @@ def qdrant_adapter( "embed must be a list of column names or a single column name as a string" ) + column_hints = {} for column_name in embed: column_hints[column_name] = { "name": column_name, diff --git 
a/dlt/destinations/impl/weaviate/weaviate_adapter.py b/dlt/destinations/impl/weaviate/weaviate_adapter.py index 9bd0b41783..0ca9047528 100644 --- a/dlt/destinations/impl/weaviate/weaviate_adapter.py +++ b/dlt/destinations/impl/weaviate/weaviate_adapter.py @@ -87,6 +87,7 @@ def weaviate_adapter( TOKENIZATION_HINT: method, # type: ignore } + # this makes sure that {} as column_hints never gets into apply_hints (that would reset existing columns) if not column_hints: raise ValueError("Either 'vectorize' or 'tokenization' must be specified.") else: diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index ae27213a7c..f59f087f4f 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -78,7 +78,6 @@ def from_table_chain( job = cls(file_info.file_name()) job._save_text_file("\n".join(sql)) except Exception as e: - # raise exception with some context raise SqlJobCreationException(e, table_chain) from e return job diff --git a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md index a85d6ddc7e..b2aec665ab 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md +++ b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md @@ -179,19 +179,61 @@ info = pipeline.run( ### Merge -The [merge](../../general-usage/incremental-loading.md) write disposition merges the data from the resource with the data at the destination based on a unique identifier. +The [merge](../../general-usage/incremental-loading.md) write disposition merges the data from the resource with the data at the destination based on a unique identifier. The LanceDB destination merge write disposition only supports upsert strategy. This updates existing records and inserts new ones based on a unique identifier. + +You can specify the merge disposition, primary key, and merge key either in a resource or adapter: + +```py +@dlt.resource( + primary_key=["doc_id", "chunk_id"], + merge_key=["doc_id"], + write_disposition={"disposition": "merge", "strategy": "upsert"}, +) +def my_rag_docs( + data: List[DictStrAny], +) -> Generator[List[DictStrAny], None, None]: + yield data +``` + +Or: + +```py +pipeline.run( + lancedb_adapter( + my_new_rag_docs, + merge_key="doc_id" + ), + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key=["doc_id", "chunk_id"], +) +``` + +The `primary_key` uniquely identifies each record, typically comprising a document ID and a chunk ID. +The `merge_key`, which cannot be compound, should correspond to the canonical `doc_id` used in vector databases and represent the document identifier in your data model. +It must be the first element of the `primary_key`. +This `merge_key` is crucial for document identification and orphan removal during merge operations. +This structure ensures proper record identification and maintains consistency with vector database concepts. + + +#### Orphan Removal + +LanceDB **automatically removes orphaned chunks** when updating or deleting parent documents during a merge operation. To disable this feature: ```py pipeline.run( lancedb_adapter( movies, embed="title", + no_remove_orphans=True # Disable with the `no_remove_orphans` flag. 
), - write_disposition="merge", - primary_key="id", + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key=["doc_id", "chunk_id"], ) ``` +Note: While it's possible to omit the `merge_key` for brevity (in which case it is assumed to be the first entry of `primary_key`), +explicitly specifying both is recommended for clarity. + ### Append This is the default disposition. It will append the data to the existing data in the destination. @@ -200,7 +242,6 @@ This is the default disposition. It will append the data to the existing data in - `dataset_separator`: The character used to separate the dataset name from table names. Defaults to "___". - `vector_field_name`: The name of the special field to store vector embeddings. Defaults to "vector". -- `id_field_name`: The name of the special field used for deduplication and merging. Defaults to "id__". - `max_retries`: The maximum number of retries for embedding operations. Set to 0 to disable retries. Defaults to 3. ## dbt support diff --git a/poetry.lock b/poetry.lock index 5d81d06969..f4667da374 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2167,32 +2167,33 @@ typing-extensions = ">=3.10.0" [[package]] name = "databricks-sql-connector" -version = "3.3.0" +version = "2.9.6" description = "Databricks SQL Connector for Python" optional = true -python-versions = "<4.0.0,>=3.8.0" +python-versions = "<4.0.0,>=3.7.1" files = [ - {file = "databricks_sql_connector-3.3.0-py3-none-any.whl", hash = "sha256:55ee5a4a11291bf91a235ac76e41b419ddd66a9a321065a8bfaf119acbb26d6b"}, - {file = "databricks_sql_connector-3.3.0.tar.gz", hash = "sha256:19e82965da4c86574adfe9f788c17b4494d98eb8075ba4fd4306573d2edbf194"}, + {file = "databricks_sql_connector-2.9.6-py3-none-any.whl", hash = "sha256:d830abf86e71d2eb83c6a7b7264d6c03926a8a83cec58541ddd6b83d693bde8f"}, + {file = "databricks_sql_connector-2.9.6.tar.gz", hash = "sha256:e55f5b8ede8ae6c6f31416a4cf6352f0ac019bf6875896c668c7574ceaf6e813"}, ] [package.dependencies] +alembic = ">=1.0.11,<2.0.0" lz4 = ">=4.0.2,<5.0.0" numpy = [ - {version = ">=1.16.6,<2.0.0", markers = "python_version >= \"3.8\" and python_version < \"3.11\""}, - {version = ">=1.23.4,<2.0.0", markers = "python_version >= \"3.11\""}, + {version = ">=1.16.6", markers = "python_version >= \"3.7\" and python_version < \"3.11\""}, + {version = ">=1.23.4", markers = "python_version >= \"3.11\""}, ] oauthlib = ">=3.1.0,<4.0.0" openpyxl = ">=3.0.10,<4.0.0" -pandas = {version = ">=1.2.5,<2.2.0", markers = "python_version >= \"3.8\""} -pyarrow = ">=14.0.1,<17" +pandas = {version = ">=1.2.5,<3.0.0", markers = "python_version >= \"3.8\""} +pyarrow = [ + {version = ">=6.0.0", markers = "python_version >= \"3.7\" and python_version < \"3.11\""}, + {version = ">=10.0.1", markers = "python_version >= \"3.11\""}, +] requests = ">=2.18.1,<3.0.0" -thrift = ">=0.16.0,<0.21.0" -urllib3 = ">=1.26" - -[package.extras] -alembic = ["alembic (>=1.0.11,<2.0.0)", "sqlalchemy (>=2.0.21)"] -sqlalchemy = ["sqlalchemy (>=2.0.21)"] +sqlalchemy = ">=1.3.24,<2.0.0" +thrift = ">=0.16.0,<0.17.0" +urllib3 = ">=1.0" [[package]] name = "db-dtypes" @@ -6760,52 +6761,55 @@ files = [ [[package]] name = "pyarrow" -version = "16.1.0" +version = "17.0.0" description = "Python library for Apache Arrow" optional = false python-versions = ">=3.8" files = [ - {file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"}, - {file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"}, - {file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"}, - {file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"}, - {file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"}, - {file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"}, - {file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"}, - {file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"}, - {file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"}, - {file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"}, - {file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"}, - {file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"}, - {file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"}, - {file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"}, - {file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"}, - {file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"}, + {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"}, + {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da1e060b3876faa11cee287839f9cc7cdc00649f475714b8680a05fd9071d545"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75c06d4624c0ad6674364bb46ef38c3132768139ddec1c56582dbac54f2663e2"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:fa3c246cc58cb5a4a5cb407a18f193354ea47dd0648194e6265bd24177982fe8"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:f7ae2de664e0b158d1607699a16a488de3d008ba99b3a7aa5de1cbc13574d047"}, + {file = "pyarrow-17.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5984f416552eea15fd9cee03da53542bf4cddaef5afecefb9aa8d1010c335087"}, + {file = "pyarrow-17.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:1c8856e2ef09eb87ecf937104aacfa0708f22dfeb039c363ec99735190ffb977"}, + {file = "pyarrow-17.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e19f569567efcbbd42084e87f948778eb371d308e137a0f97afe19bb860ccb3"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b244dc8e08a23b3e352899a006a26ae7b4d0da7bb636872fa8f5884e70acf15"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b72e87fe3e1db343995562f7fff8aee354b55ee83d13afba65400c178ab2597"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dc5c31c37409dfbc5d014047817cb4ccd8c1ea25d19576acf1a001fe07f5b420"}, + {file = 
"pyarrow-17.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e3343cb1e88bc2ea605986d4b94948716edc7a8d14afd4e2c097232f729758b4"}, + {file = "pyarrow-17.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:a27532c38f3de9eb3e90ecab63dfda948a8ca859a66e3a47f5f42d1e403c4d03"}, + {file = "pyarrow-17.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9b8a823cea605221e61f34859dcc03207e52e409ccf6354634143e23af7c8d22"}, + {file = "pyarrow-17.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f1e70de6cb5790a50b01d2b686d54aaf73da01266850b05e3af2a1bc89e16053"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0071ce35788c6f9077ff9ecba4858108eebe2ea5a3f7cf2cf55ebc1dbc6ee24a"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:757074882f844411fcca735e39aae74248a1531367a7c80799b4266390ae51cc"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ba11c4f16976e89146781a83833df7f82077cdab7dc6232c897789343f7891a"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b0c6ac301093b42d34410b187bba560b17c0330f64907bfa4f7f7f2444b0cf9b"}, + {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"}, + {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"}, + {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"}, + {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"}, + {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"}, + {file = "pyarrow-17.0.0.tar.gz", hash = 
"sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"}, ] [package.dependencies] numpy = ">=1.16.6" +[package.extras] +test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] + [[package]] name = "pyasn1" version = "0.5.0" @@ -9867,4 +9871,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "97666ad4613f07d95c5388bae41befe6cc10c88d02ee8f1cee27b161e13729f1" +content-hash = "6393c17b6865a78b3adb2732e4b2e416b7dc869f07649e60cc15352cab49444f" diff --git a/pyproject.toml b/pyproject.toml index 165364908c..38262a95d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -239,7 +239,7 @@ dbt-duckdb = ">=1.2.0" pymongo = ">=4.3.3" pandas = ">2" alive-progress = ">=3.0.1" -pyarrow = ">=14.0.0" +pyarrow = ">=17.0.0" psycopg2-binary = ">=2.9" lancedb = { version = ">=0.8.2", markers = "python_version >= '3.9'", allow-prereleases = true } openai = ">=1.45" diff --git a/tests/load/lancedb/test_merge.py b/tests/load/lancedb/test_merge.py new file mode 100644 index 0000000000..f04c846df7 --- /dev/null +++ b/tests/load/lancedb/test_merge.py @@ -0,0 +1,425 @@ +from typing import Iterator, List, Generator, Any + +import numpy as np +import pandas as pd +import pytest +from lancedb.table import Table # type: ignore +from pandas import DataFrame +from pandas.testing import assert_frame_equal + +import dlt +from dlt.common.typing import DictStrAny, DictStrStr +from dlt.common.utils import uniq_id +from dlt.destinations.impl.lancedb.lancedb_adapter import ( + lancedb_adapter, +) +from tests.load.lancedb.utils import chunk_document +from tests.load.utils import ( + drop_active_pipeline_data, + sequence_generator, +) +from tests.pipeline.utils import ( + assert_load_info, +) + + +# Mark all tests as essential, don't remove. +pytestmark = pytest.mark.essential + + +@pytest.fixture(autouse=True) +def drop_lancedb_data() -> Iterator[None]: + yield + drop_active_pipeline_data() + + +def test_lancedb_remove_nested_orphaned_records() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_remove_orphaned_records", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_{uniq_id()}", + dev_mode=True, + ) + + @dlt.resource( + table_name="parent", + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key="id", + merge_key="id", + ) + def identity_resource( + data: List[DictStrAny], + ) -> Generator[List[DictStrAny], None, None]: + yield data + + run_1 = [ + { + "id": 1, + "child": [ + {"bar": 1, "grandchild": [{"baz": 1}, {"baz": 2}]}, + {"bar": 2, "grandchild": [{"baz": 3}]}, + ], + }, + {"id": 2, "child": [{"bar": 3, "grandchild": [{"baz": 4}]}]}, + { + "id": 3, + "child": [ + {"bar": 10, "grandchild": [{"baz": 5}]}, + {"bar": 11, "grandchild": [{"baz": 6}, {"baz": 7}]}, + ], + }, + ] + info = pipeline.run(identity_resource(run_1)) + assert_load_info(info) + + run_2 = [ + { + "id": 1, + "child": [{"bar": 1, "grandchild": [{"baz": 1}]}], + }, # Removes bar_2, baz_2 and baz_3. + { + "id": 2, + "child": [{"bar": 4, "grandchild": [{"baz": 8}]}], + }, # Removes bar_3, baz_4. 
+ ] + info = pipeline.run(identity_resource(run_2)) + assert_load_info(info) + + with pipeline.destination_client() as client: + expected_parent_data = pd.DataFrame( + data=[ + {"id": 1}, + {"id": 2}, + {"id": 3}, + ] + ) + + expected_child_data = pd.DataFrame( + data=[ + {"bar": 1}, + {"bar": 4}, + {"bar": 10}, + {"bar": 11}, + ] + ) + + expected_grandchild_data = pd.DataFrame( + data=[ + {"baz": 1}, + {"baz": 8}, + {"baz": 5}, + {"baz": 6}, + {"baz": 7}, + ] + ) + + parent_table_name = client.make_qualified_table_name("parent") # type: ignore[attr-defined] + child_table_name = client.make_qualified_table_name("parent__child") # type: ignore[attr-defined] + grandchild_table_name = client.make_qualified_table_name( # type: ignore[attr-defined] + "parent__child__grandchild" + ) + + parent_tbl = client.db_client.open_table(parent_table_name) # type: ignore[attr-defined] + child_tbl = client.db_client.open_table(child_table_name) # type: ignore[attr-defined] + grandchild_tbl = client.db_client.open_table(grandchild_table_name) # type: ignore[attr-defined] + + actual_parent_df = parent_tbl.to_pandas().sort_values(by="id").reset_index(drop=True) + actual_child_df = child_tbl.to_pandas().sort_values(by="bar").reset_index(drop=True) + actual_grandchild_df = ( + grandchild_tbl.to_pandas().sort_values(by="baz").reset_index(drop=True) + ) + + expected_parent_data = expected_parent_data.sort_values(by="id").reset_index(drop=True) + expected_child_data = expected_child_data.sort_values(by="bar").reset_index(drop=True) + expected_grandchild_data = expected_grandchild_data.sort_values(by="baz").reset_index( + drop=True + ) + + assert_frame_equal(actual_parent_df[["id"]], expected_parent_data) + assert_frame_equal(actual_child_df[["bar"]], expected_child_data) + assert_frame_equal(actual_grandchild_df[["baz"]], expected_grandchild_data) + + +def test_lancedb_remove_orphaned_records_root_table() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_remove_orphaned_records_root_table", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_root_table_{uniq_id()}", + dev_mode=True, + ) + + @dlt.resource( + table_name="root", + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key=["doc_id", "chunk_hash"], + merge_key=["doc_id"], + ) + def identity_resource( + data: List[DictStrAny], + ) -> Generator[List[DictStrAny], None, None]: + yield data + + lancedb_adapter(identity_resource) + + run_1 = [ + {"doc_id": 1, "chunk_hash": "1a"}, + {"doc_id": 2, "chunk_hash": "2a"}, + {"doc_id": 2, "chunk_hash": "2b"}, + {"doc_id": 2, "chunk_hash": "2c"}, + {"doc_id": 3, "chunk_hash": "3a"}, + {"doc_id": 3, "chunk_hash": "3b"}, + ] + info = pipeline.run(identity_resource(run_1)) + assert_load_info(info) + + run_2 = [ + {"doc_id": 2, "chunk_hash": "2d"}, + {"doc_id": 2, "chunk_hash": "2e"}, + {"doc_id": 3, "chunk_hash": "3b"}, + ] + info = pipeline.run(identity_resource(run_2)) + assert_load_info(info) + + with pipeline.destination_client() as client: + expected_root_table_df = ( + pd.DataFrame( + data=[ + {"doc_id": 1, "chunk_hash": "1a"}, + {"doc_id": 2, "chunk_hash": "2d"}, + {"doc_id": 2, "chunk_hash": "2e"}, + {"doc_id": 3, "chunk_hash": "3b"}, + ] + ) + .sort_values(by=["doc_id", "chunk_hash"]) + .reset_index(drop=True) + ) + + root_table_name = client.make_qualified_table_name("root") # type: ignore[attr-defined] + tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] + + actual_root_df: DataFrame = ( + 
tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash"]).reset_index(drop=True) + )[["doc_id", "chunk_hash"]] + + assert_frame_equal(actual_root_df, expected_root_table_df) + + +def test_lancedb_remove_orphaned_records_root_table_string_doc_id() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_remove_orphaned_records_root_table", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_root_table_{uniq_id()}", + dev_mode=True, + ) + + @dlt.resource( + table_name="root", + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key=["doc_id", "chunk_hash"], + merge_key=["doc_id"], + ) + def identity_resource( + data: List[DictStrAny], + ) -> Generator[List[DictStrAny], None, None]: + yield data + + lancedb_adapter(identity_resource) + + run_1 = [ + {"doc_id": "A", "chunk_hash": "1a"}, + {"doc_id": "B", "chunk_hash": "2a"}, + {"doc_id": "B", "chunk_hash": "2b"}, + {"doc_id": "B", "chunk_hash": "2c"}, + {"doc_id": "C", "chunk_hash": "3a"}, + {"doc_id": "C", "chunk_hash": "3b"}, + ] + info = pipeline.run(identity_resource(run_1)) + assert_load_info(info) + + run_2 = [ + {"doc_id": "B", "chunk_hash": "2d"}, + {"doc_id": "B", "chunk_hash": "2e"}, + {"doc_id": "C", "chunk_hash": "3b"}, + ] + info = pipeline.run(identity_resource(run_2)) + assert_load_info(info) + + with pipeline.destination_client() as client: + expected_root_table_df = ( + pd.DataFrame( + data=[ + {"doc_id": "A", "chunk_hash": "1a"}, + {"doc_id": "B", "chunk_hash": "2d"}, + {"doc_id": "B", "chunk_hash": "2e"}, + {"doc_id": "C", "chunk_hash": "3b"}, + ] + ) + .sort_values(by=["doc_id", "chunk_hash"]) + .reset_index(drop=True) + ) + + root_table_name = client.make_qualified_table_name("root") # type: ignore[attr-defined] + tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] + + actual_root_df: DataFrame = ( + tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash"]).reset_index(drop=True) + )[["doc_id", "chunk_hash"]] + + assert_frame_equal(actual_root_df, expected_root_table_df) + + +def test_lancedb_root_table_remove_orphaned_records_with_real_embeddings() -> None: + @dlt.resource( + write_disposition={"disposition": "merge", "strategy": "upsert"}, + table_name="document", + primary_key=["doc_id", "chunk"], + merge_key="doc_id", + ) + def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: + for doc in docs: + doc_id = doc["doc_id"] + for chunk in chunk_document(doc["text"]): + yield {"doc_id": doc_id, "doc_text": doc["text"], "chunk": chunk} + + @dlt.source() + def documents_source( + docs: List[DictStrAny], + ) -> Any: + return documents(docs) + + lancedb_adapter( + documents, + embed=["chunk"], + ) + + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_remove_orphaned_records_with_embeddings", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_{uniq_id()}", + dev_mode=True, + ) + + initial_docs = [ + { + "text": ( + "This is the first document. It contains some text that will be chunked and" + " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)" + ), + "doc_id": 1, + }, + { + "text": "Here's another document. 
It's a bit different from the first one.", + "doc_id": 2, + }, + ] + + info = pipeline.run(documents_source(initial_docs)) + assert_load_info(info) + + updated_docs = [ + { + "text": "This is the first document, but it has been updated with new content.", + "doc_id": 1, + }, + { + "text": "This is a completely new document that wasn't in the initial set.", + "doc_id": 3, + }, + ] + + info = pipeline.run(documents_source(updated_docs)) + assert_load_info(info) + + with pipeline.destination_client() as client: + embeddings_table_name = client.make_qualified_table_name("document") # type: ignore[attr-defined] + tbl: Table = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] + df = tbl.to_pandas() + + # Check (non-empty) embeddings as present, and that orphaned embeddings have been discarded. + assert len(df) == 21 + assert "vector" in df.columns + for _, vector in enumerate(df["vector"]): + assert isinstance(vector, np.ndarray) + assert vector.size > 0 + + +def test_lancedb_compound_merge_key_root_table() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_compound_merge_key", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_root_table_{uniq_id()}", + dev_mode=True, + ) + + @dlt.resource( + table_name="root", + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key=["doc_id", "chunk_hash"], + merge_key=["doc_id", "chunk_hash"], + ) + def identity_resource( + data: List[DictStrAny], + ) -> Generator[List[DictStrAny], None, None]: + yield data + + lancedb_adapter(identity_resource, no_remove_orphans=True) + + run_1 = [ + {"doc_id": 1, "chunk_hash": "a", "foo": "bar"}, + {"doc_id": 1, "chunk_hash": "b", "foo": "coo"}, + ] + info = pipeline.run(identity_resource(run_1)) + assert_load_info(info) + + run_2 = [ + {"doc_id": 1, "chunk_hash": "a", "foo": "aat"}, + {"doc_id": 1, "chunk_hash": "c", "foo": "loot"}, + ] + info = pipeline.run(identity_resource(run_2)) + assert_load_info(info) + + with pipeline.destination_client() as client: + expected_root_table_df = ( + pd.DataFrame( + data=[ + {"doc_id": 1, "chunk_hash": "a", "foo": "aat"}, + {"doc_id": 1, "chunk_hash": "b", "foo": "coo"}, + {"doc_id": 1, "chunk_hash": "c", "foo": "loot"}, + ] + ) + .sort_values(by=["doc_id", "chunk_hash", "foo"]) + .reset_index(drop=True) + ) + + root_table_name = client.make_qualified_table_name("root") # type: ignore[attr-defined] + tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] + + actual_root_df: DataFrame = ( + tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash", "foo"]).reset_index(drop=True) + )[["doc_id", "chunk_hash", "foo"]] + + assert_frame_equal(actual_root_df, expected_root_table_df) + + +def test_must_provide_at_least_primary_key_on_merge_disposition() -> None: + """We need upsert merge's deterministic _dlt_id to perform orphan removal. + Hence, we require at least the primary key required (raises exception if missing). 
+ Specify a merge key for custom orphan identification.""" + generator_instance1 = sequence_generator() + + @dlt.resource(write_disposition={"disposition": "merge", "strategy": "upsert"}) + def some_data() -> Generator[DictStrStr, Any, None]: + yield from next(generator_instance1) + + pipeline = dlt.pipeline( + pipeline_name="test_must_provide_both_primary_and_merge_key_on_merge_disposition", + destination="lancedb", + dataset_name=( + f"test_must_provide_both_primary_and_merge_key_on_merge_disposition{uniq_id()}" + ), + ) + with pytest.raises(Exception): + load_info = pipeline.run( + some_data(), + ) + assert_load_info(load_info) diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 6cd0abd587..345934fb29 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -1,25 +1,30 @@ import multiprocessing -from typing import Iterator, Generator, Any, List, Mapping +import os +from typing import Iterator, Generator, Any, List +from typing import Mapping +from typing import Union, Dict import pytest -import lancedb # type: ignore -from lancedb import DBConnection +from lancedb import DBConnection # type: ignore from lancedb.embeddings import EmbeddingFunctionRegistry # type: ignore +from lancedb.table import Table # type: ignore import dlt from dlt.common import json -from dlt.common.typing import DictStrStr, DictStrAny -from dlt.common.utils import uniq_id +from dlt.common.typing import DictStrAny +from dlt.common.typing import DictStrStr +from dlt.common.utils import uniq_id, digest128 from dlt.destinations.impl.lancedb.lancedb_adapter import ( lancedb_adapter, VECTORIZE_HINT, ) from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient -from tests.load.lancedb.utils import assert_table +from dlt.extract import DltResource +from tests.load.lancedb.utils import assert_table, chunk_document, mock_embed from tests.load.utils import sequence_generator, drop_active_pipeline_data from tests.pipeline.utils import assert_load_info -# Mark all tests as essential, do not remove. +# Mark all tests as essential, don't remove. 
pytestmark = pytest.mark.essential @@ -49,6 +54,22 @@ def some_data() -> Generator[DictStrStr, Any, None]: "x-lancedb-embed": True, } + lancedb_adapter( + some_data, + merge_key="content", + ) + + # via merge_key + assert some_data._hints["merge_key"] == "content" + + assert some_data.columns["content"] == { # type: ignore + "name": "content", + "data_type": "text", + "x-lancedb-embed": True, + } + + assert some_data.compute_table_schema()["columns"]["content"]["merge_key"] is True + def test_basic_state_and_schema() -> None: generator_instance1 = sequence_generator() @@ -118,14 +139,13 @@ def some_data() -> Generator[DictStrStr, Any, None]: def test_explicit_append() -> None: - """Append should work even when the primary key is specified.""" data = [ {"doc_id": 1, "content": "1"}, {"doc_id": 2, "content": "2"}, {"doc_id": 3, "content": "3"}, ] - @dlt.resource(primary_key="doc_id") + @dlt.resource() def some_data() -> Generator[List[DictStrAny], Any, None]: yield data @@ -142,6 +162,7 @@ def some_data() -> Generator[List[DictStrAny], Any, None]: info = pipeline.run( some_data(), ) + assert_load_info(info) assert_table(pipeline, "some_data", items=data) @@ -156,25 +177,22 @@ def some_data() -> Generator[List[DictStrAny], Any, None]: def test_pipeline_replace() -> None: - generator_instance1 = sequence_generator() - generator_instance2 = sequence_generator() + os.environ["DATA_WRITER__BUFFER_MAX_ITEMS"] = "2" + os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "2" + + generator_instance1, generator_instance2 = (sequence_generator(), sequence_generator()) @dlt.resource def some_data() -> Generator[DictStrStr, Any, None]: yield from next(generator_instance1) - lancedb_adapter( - some_data, - embed=["content"], - ) - uid = uniq_id() pipeline = dlt.pipeline( pipeline_name="test_pipeline_replace", destination="lancedb", dataset_name="test_pipeline_replace_dataset" - + uid, # lancedb doesn't mandate any name normalization + + uid, # Lancedb doesn't mandate any name normalization. ) info = pipeline.run( @@ -263,23 +281,11 @@ def test_pipeline_merge() -> None: }, ] - @dlt.resource(primary_key="doc_id") + @dlt.resource(primary_key=["doc_id"]) def movies_data() -> Any: yield data - @dlt.resource(primary_key="doc_id", merge_key=["merge_id", "title"]) - def movies_data_explicit_merge_keys() -> Any: - yield data - - lancedb_adapter( - movies_data, - embed=["description"], - ) - - lancedb_adapter( - movies_data_explicit_merge_keys, - embed=["description"], - ) + lancedb_adapter(movies_data, embed=["description"], no_remove_orphans=True) pipeline = dlt.pipeline( pipeline_name="movies", @@ -288,7 +294,7 @@ def movies_data_explicit_merge_keys() -> Any: ) info = pipeline.run( movies_data(), - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": "upsert"}, dataset_name=f"MoviesDataset{uniq_id()}", ) assert_load_info(info) @@ -299,26 +305,11 @@ def movies_data_explicit_merge_keys() -> Any: info = pipeline.run( movies_data(), - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": "upsert"}, ) assert_load_info(info) assert_table(pipeline, "movies_data", items=data) - info = pipeline.run( - movies_data(), - write_disposition="merge", - ) - assert_load_info(info) - assert_table(pipeline, "movies_data", items=data) - - # Test with explicit merge keys. 
- info = pipeline.run( - movies_data_explicit_merge_keys(), - write_disposition="merge", - ) - assert_load_info(info) - assert_table(pipeline, "movies_data_explicit_merge_keys", items=data) - def test_pipeline_with_schema_evolution() -> None: data = [ @@ -388,9 +379,9 @@ def test_merge_github_nested() -> None: data = json.load(f) info = pipe.run( - lancedb_adapter(data[:17], embed=["title", "body"]), + lancedb_adapter(data[:17], embed=["title", "body"], no_remove_orphans=True), table_name="issues", - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": "upsert"}, primary_key="id", ) assert_load_info(info) @@ -426,18 +417,116 @@ def test_merge_github_nested() -> None: def test_empty_dataset_allowed() -> None: # dataset_name is optional so dataset name won't be autogenerated when not explicitly passed. pipe = dlt.pipeline(destination="lancedb", dev_mode=True) - client: LanceDBClient = pipe.destination_client() # type: ignore[assignment] assert pipe.dataset_name is None info = pipe.run(lancedb_adapter(["context", "created", "not a stop word"], embed=["value"])) # Dataset in load info is empty. assert info.dataset_name is None - client = pipe.destination_client() # type: ignore[assignment] - assert client.dataset_name is None - assert client.sentinel_table == "dltSentinelTable" + client = pipe.destination_client() + assert client.dataset_name is None # type: ignore + assert client.sentinel_table == "dltSentinelTable" # type: ignore assert_table(pipe, "content", expected_items_count=3) +def test_lancedb_remove_nested_orphaned_records_with_chunks() -> None: + @dlt.resource( + write_disposition={"disposition": "merge", "strategy": "upsert"}, + table_name="document", + primary_key=["doc_id"], + merge_key=["doc_id"], + ) + def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: + for doc in docs: + doc_id = doc["doc_id"] + chunks = chunk_document(doc["text"]) + embeddings = [ + { + "chunk_hash": digest128(chunk), + "chunk_text": chunk, + "embedding": mock_embed(), + } + for chunk in chunks + ] + yield {"doc_id": doc_id, "doc_text": doc["text"], "embeddings": embeddings} + + @dlt.source(max_table_nesting=1) + def documents_source( + docs: List[DictStrAny], + ) -> Union[Generator[Dict[str, Any], None, None], DltResource]: + return documents(docs) + + pipeline = dlt.pipeline( + pipeline_name="chunked_docs", + destination="lancedb", + dataset_name="chunked_documents", + dev_mode=True, + ) + + initial_docs = [ + { + "text": ( + "This is the first document. It contains some text that will be chunked and" + " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)" + ), + "doc_id": 1, + }, + { + "text": "Here's another document. It's a bit different from the first one.", + "doc_id": 2, + }, + ] + + info = pipeline.run(documents_source(initial_docs)) + assert_load_info(info) + + updated_docs = [ + { + "text": "This is the first document, but it has been updated with new content.", + "doc_id": 1, + }, + { + "text": "This is a completely new document that wasn't in the initial set.", + "doc_id": 3, + }, + ] + + info = pipeline.run(documents_source(updated_docs)) + assert_load_info(info) + + with pipeline.destination_client() as client: + # Orphaned chunks/documents must have been discarded. + # Shouldn't contain any text from `initial_docs' where doc_id=1. + expected_text = { + "Here's ano", + "ther docum", + "ent. 
It's ", + "a bit diff", + "erent from", + " the first", + " one.", + "This is th", + "e first do", + "cument, bu", + "t it has b", + "een update", + "d with new", + " content.", + "This is a ", + "completely", + " new docum", + "ent that w", + "asn't in t", + "he initial", + " set.", + } + + embeddings_table_name = client.make_qualified_table_name("document__embeddings") # type: ignore[attr-defined] + + tbl: Table = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] + df = tbl.to_pandas() + assert set(df["chunk_text"]) == expected_text + + search_data = [ {"text": "Frodo was a happy puppy"}, {"text": "There are several kittens playing"}, diff --git a/tests/load/lancedb/test_utils.py b/tests/load/lancedb/test_utils.py new file mode 100644 index 0000000000..d7f9729f26 --- /dev/null +++ b/tests/load/lancedb/test_utils.py @@ -0,0 +1,46 @@ +import pyarrow as pa +import pytest + +from dlt.destinations.impl.lancedb.utils import ( + create_filter_condition, + fill_empty_source_column_values_with_placeholder, +) + + +# Mark all tests as essential, don't remove. +pytestmark = pytest.mark.essential + + +def test_fill_empty_source_column_values_with_placeholder() -> None: + data = [ + pa.array(["", "hello", ""]), + pa.array(["hello", None, ""]), + pa.array([1, 2, 3]), + pa.array(["world", "", "arrow"]), + ] + table = pa.Table.from_arrays(data, names=["A", "B", "C", "D"]) + + source_columns = ["A", "B"] + placeholder = "placeholder" + + new_table = fill_empty_source_column_values_with_placeholder(table, source_columns, placeholder) + + expected_data = [ + pa.array(["placeholder", "hello", "placeholder"]), + pa.array(["hello", "placeholder", "placeholder"]), + pa.array([1, 2, 3]), + pa.array(["world", "", "arrow"]), + ] + expected_table = pa.Table.from_arrays(expected_data, names=["A", "B", "C", "D"]) + assert new_table.equals(expected_table) + + +def test_create_filter_condition() -> None: + assert ( + create_filter_condition("_dlt_load_id", pa.array(["A", "B", "C'c\n"])) + == "_dlt_load_id IN ('A', 'B', 'C''c\\n')" + ) + assert ( + create_filter_condition("_dlt_load_id", pa.array([1.2, 3, 5 / 2])) + == "_dlt_load_id IN (1.2, 3.0, 2.5)" + ) diff --git a/tests/load/lancedb/utils.py b/tests/load/lancedb/utils.py index 7431e895b7..30430fe076 100644 --- a/tests/load/lancedb/utils.py +++ b/tests/load/lancedb/utils.py @@ -40,7 +40,7 @@ def assert_table( exists = client.table_exists(qualified_table_name) assert exists - records = client.db_client.open_table(qualified_table_name).search().limit(50).to_list() + records = client.db_client.open_table(qualified_table_name).search().limit(0).to_list() if expected_items_count is not None: assert expected_items_count == len(records) @@ -51,7 +51,6 @@ def assert_table( drop_keys = [ "_dlt_id", "_dlt_load_id", - dlt.config.get("destination.lancedb.credentials.id_field_name", str) or "id__", dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector", ] objects_without_dlt_or_special_keys = [ @@ -72,3 +71,13 @@ def generate_embeddings( def ndims(self) -> int: return 2 + + +def mock_embed( + dim: int = 10, +) -> str: + return str(np.random.random_sample(dim)) + + +def chunk_document(doc: str, chunk_size: int = 10) -> List[str]: + return [doc[i : i + chunk_size] for i in range(0, len(doc), chunk_size)] diff --git a/tests/load/utils.py b/tests/load/utils.py index 3edf111a36..e7d6476a3a 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -1004,7 +1004,7 @@ def prepare_load_package( def sequence_generator() -> 
Generator[List[Dict[str, str]], None, None]: count = 1 while True: - yield [{"content": str(count + i)} for i in range(3)] + yield [{"content": str(count + i)} for i in range(2000)] count += 3
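
The docs and tests above all follow the same pattern: a compound `primary_key` identifies a single chunk, a single-column `merge_key` identifies its parent document, and the upsert merge strategy drops chunks that disappear when a document is re-loaded. Below is a minimal end-to-end sketch of that pattern, assuming a configured LanceDB destination (credentials in `secrets.toml` or environment variables); the pipeline, dataset, and resource names are illustrative and not taken from this patch.

```py
from typing import Any, Dict, Generator, List

import dlt


@dlt.resource(
    table_name="docs",
    primary_key=["doc_id", "chunk_hash"],  # identifies a single chunk
    merge_key="doc_id",                    # identifies the parent document (cannot be compound)
    write_disposition={"disposition": "merge", "strategy": "upsert"},
)
def docs_resource(rows: List[Dict[str, Any]]) -> Generator[List[Dict[str, Any]], None, None]:
    yield rows


if __name__ == "__main__":
    pipeline = dlt.pipeline(
        pipeline_name="lancedb_orphan_demo",  # illustrative name
        destination="lancedb",
        dataset_name="lancedb_orphan_demo_data",
        dev_mode=True,
    )

    # First load: document 1 arrives as three chunks.
    first = [
        {"doc_id": 1, "chunk_hash": "1a", "text": "chunk a"},
        {"doc_id": 1, "chunk_hash": "1b", "text": "chunk b"},
        {"doc_id": 1, "chunk_hash": "1c", "text": "chunk c"},
    ]
    pipeline.run(docs_resource(first))

    # Second load: document 1 is re-loaded with a single new chunk. With orphan
    # removal enabled (the default for merge/upsert), chunks "1b" and "1c" are
    # deleted because their merge_key value was re-loaded but their primary keys
    # no longer appear in the incoming data.
    second = [{"doc_id": 1, "chunk_hash": "1d", "text": "rewritten chunk"}]
    pipeline.run(docs_resource(second))

    # Wrapping the resource with lancedb_adapter(docs_resource, no_remove_orphans=True)
    # would keep the stale chunks instead.
```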
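
The assertions in `tests/load/lancedb/test_utils.py` pin down how the orphan-removal job is expected to build its SQL `IN` predicates from arrow arrays. The helper below is a simplified stand-in written only to illustrate that contract; it is not dlt's actual `create_filter_condition` implementation.

```py
from typing import Union

import pyarrow as pa


def build_filter_condition(column: str, values: pa.Array) -> str:
    """Build a `col IN (...)` predicate; strings are SQL-quoted, numerics rendered as floats."""

    def render(value: Union[str, float, int]) -> str:
        if isinstance(value, str):
            # Double single quotes (standard SQL escaping) and escape newlines.
            return "'" + value.replace("'", "''").replace("\n", "\\n") + "'"
        return str(float(value))

    return f"{column} IN ({', '.join(render(v) for v in values.to_pylist())})"


# Mirrors the expectations asserted in test_create_filter_condition.
assert (
    build_filter_condition("_dlt_load_id", pa.array(["A", "B", "C'c\n"]))
    == "_dlt_load_id IN ('A', 'B', 'C''c\\n')"
)
assert (
    build_filter_condition("_dlt_load_id", pa.array([1.2, 3, 5 / 2]))
    == "_dlt_load_id IN (1.2, 3.0, 2.5)"
)
```

Doubling single quotes rather than backslash-escaping them matches standard SQL string quoting, which is what the test expects the generated predicate to use.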