diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py index 06c8d7a95a..393e9e8508 100644 --- a/dlt/common/data_writers/escape.py +++ b/dlt/common/data_writers/escape.py @@ -79,6 +79,23 @@ def escape_duckdb_literal(v: Any) -> Any: return str(v) +def escape_lancedb_literal(v: Any) -> Any: + if isinstance(v, str): + # we escape extended string which behave like the redshift string + return _escape_extended(v, prefix="'") + if isinstance(v, (datetime, date, time)): + return f"'{v.isoformat()}'" + if isinstance(v, (list, dict)): + return _escape_extended(json.dumps(v), prefix="'") + # TODO: check how binaries are represented in fusion + if isinstance(v, bytes): + return f"from_base64('{base64.b64encode(v).decode('ascii')}')" + if v is None: + return "NULL" + + return str(v) + + MS_SQL_ESCAPE_DICT = { "'": "''", "\n": "' + CHAR(10) + N'", diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 75bd8ffa13..718427af87 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -224,7 +224,7 @@ def __init__( ) super().__init__(schema, config, sql_client) self.config: DatabricksClientConfiguration = config - self.sql_client: DatabricksSqlClient = sql_client # type: ignore[assignment] + self.sql_client: DatabricksSqlClient = sql_client self.type_mapper = self.capabilities.get_type_mapper() def create_load_job( diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index 88d47410d5..8bff4e0d73 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -41,7 +41,7 @@ class DatabricksCursorImpl(DBApiCursorImpl): """Use native data frame support if available""" - native_cursor: DatabricksSqlCursor # type: ignore[assignment] + native_cursor: DatabricksSqlCursor vector_size: ClassVar[int] = 2048 # vector size is 2048 def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]: @@ -144,7 +144,7 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB # db_args = kwargs or None db_args = args or kwargs or None - with self._conn.cursor() as curr: # type: ignore[assignment] + with self._conn.cursor() as curr: curr.execute(query, db_args) yield DatabricksCursorImpl(curr) # type: ignore[abstract] diff --git a/dlt/destinations/impl/lancedb/configuration.py b/dlt/destinations/impl/lancedb/configuration.py index 329132f495..8f6a192bb0 100644 --- a/dlt/destinations/impl/lancedb/configuration.py +++ b/dlt/destinations/impl/lancedb/configuration.py @@ -59,6 +59,7 @@ class LanceDBClientOptions(BaseConfiguration): "sentence-transformers", "huggingface", "colbert", + "ollama", ] @@ -92,8 +93,6 @@ class LanceDBClientConfiguration(DestinationClientDwhConfiguration): Make sure it corresponds with the associated embedding model's dimensionality.""" vector_field_name: str = "vector" """Name of the special field to store the vector embeddings.""" - id_field_name: str = "id__" - """Name of the special field to manage deduplication.""" sentinel_table_name: str = "dltSentinelTable" """Name of the sentinel table that encapsulates datasets. 
Since LanceDB has no concept of schemas, this table serves as a proxy to group related dlt tables together.""" diff --git a/dlt/destinations/impl/lancedb/factory.py b/dlt/destinations/impl/lancedb/factory.py index 8ce2217007..d0d22ed3fb 100644 --- a/dlt/destinations/impl/lancedb/factory.py +++ b/dlt/destinations/impl/lancedb/factory.py @@ -26,8 +26,8 @@ class lancedb(Destination[LanceDBClientConfiguration, "LanceDBClient"]): def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "jsonl" - caps.supported_loader_file_formats = ["jsonl"] + caps.preferred_loader_file_format = "parquet" + caps.supported_loader_file_formats = ["parquet", "reference"] caps.type_mapper = LanceDBTypeMapper caps.max_identifier_length = 200 @@ -42,6 +42,10 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.timestamp_precision = 6 caps.supported_replace_strategies = ["truncate-and-insert"] + caps.recommended_file_size = 128_000_000 + + caps.supported_merge_strategies = ["upsert"] + return caps @property diff --git a/dlt/destinations/impl/lancedb/lancedb_adapter.py b/dlt/destinations/impl/lancedb/lancedb_adapter.py index 99d5ef43c6..4314dd703f 100644 --- a/dlt/destinations/impl/lancedb/lancedb_adapter.py +++ b/dlt/destinations/impl/lancedb/lancedb_adapter.py @@ -1,16 +1,20 @@ -from typing import Any +from typing import Any, Dict from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns from dlt.destinations.utils import get_resource_for_adapter from dlt.extract import DltResource +from dlt.extract.items import TTableHintTemplate VECTORIZE_HINT = "x-lancedb-embed" +NO_REMOVE_ORPHANS_HINT = "x-lancedb-remove-orphans" def lancedb_adapter( data: Any, embed: TColumnNames = None, + merge_key: TColumnNames = None, + no_remove_orphans: bool = False, ) -> DltResource: """Prepares data for the LanceDB destination by specifying which columns should be embedded. @@ -20,6 +24,10 @@ def lancedb_adapter( object. embed (TColumnNames, optional): Specify columns to generate embeddings for. It can be a single column name as a string, or a list of column names. + merge_key (TColumnNames, optional): Specify columns to merge on. + It can be a single column name as a string, or a list of column names. + no_remove_orphans (bool): Specify whether to remove orphaned records in child + tables with no parent records after merges to maintain referential integrity. Returns: DltResource: A resource with applied LanceDB-specific hints. @@ -34,7 +42,8 @@ def lancedb_adapter( """ resource = get_resource_for_adapter(data) - column_hints: TTableSchemaColumns = {} + additional_table_hints: Dict[str, TTableHintTemplate[Any]] = {} + column_hints: TTableSchemaColumns = None if embed: if isinstance(embed, str): @@ -43,6 +52,7 @@ def lancedb_adapter( raise ValueError( "'embed' must be a list of column names or a single column name as a string." 
) + column_hints = {} for column_name in embed: column_hints[column_name] = { @@ -50,9 +60,16 @@ "name": column_name, VECTORIZE_HINT: True, # type: ignore[misc] } - if not column_hints: - raise ValueError("A value for 'embed' must be specified.") + additional_table_hints[NO_REMOVE_ORPHANS_HINT] = no_remove_orphans + + if column_hints or additional_table_hints or merge_key: + resource.apply_hints( + merge_key=merge_key, columns=column_hints, additional_table_hints=additional_table_hints + ) else: - resource.apply_hints(columns=column_hints) + raise ValueError( + "You must provide at least one of the 'embed', 'merge_key' or 'no_remove_orphans'" + " arguments when using the adapter." + ) return resource diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 8a347989a0..1a3e1a7d34 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -1,4 +1,3 @@ -import uuid from types import TracebackType from typing import ( List, @@ -12,15 +11,17 @@ Dict, Sequence, TYPE_CHECKING, + Set, ) -from dlt.common.destination.capabilities import DataTypeMapper import lancedb # type: ignore +import lancedb.table # type: ignore import pyarrow as pa +import pyarrow.parquet as pq from lancedb import DBConnection +from lancedb.common import DATA # type: ignore from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore -from lancedb.table import Table # type: ignore from numpy import ndarray from pyarrow import Array, ChunkedArray, ArrowInvalid @@ -39,53 +40,142 @@ StorageSchemaInfo, StateInfo, LoadJob, + HasFollowupJobs, + FollowupJobRequest, ) from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TSchemaTables from dlt.common.schema.typing import ( - C_DLT_LOAD_ID, + TColumnType, TTableSchemaColumns, TWriteDisposition, + TColumnSchema, + TTableSchema, ) -from dlt.common.schema.utils import get_columns_names_with_prop -from dlt.common.storages import FileStorage -from dlt.common.typing import DictStrAny +from dlt.common.schema.utils import get_columns_names_with_prop, is_nested_table +from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName from dlt.destinations.impl.lancedb.configuration import ( LanceDBClientConfiguration, ) from dlt.destinations.impl.lancedb.exceptions import ( lancedb_error, ) -from dlt.destinations.impl.lancedb.lancedb_adapter import VECTORIZE_HINT +from dlt.destinations.impl.lancedb.lancedb_adapter import ( + VECTORIZE_HINT, + NO_REMOVE_ORPHANS_HINT, +) from dlt.destinations.impl.lancedb.schema import ( make_arrow_field_schema, make_arrow_table_schema, TArrowSchema, NULL_SCHEMA, TArrowField, + arrow_datatype_to_fusion_datatype, + TTableLineage, + TableJob, ) from dlt.destinations.impl.lancedb.utils import ( - list_merge_identifiers, - generate_uuid, set_non_standard_providers_environment_variables, + EMPTY_STRING_PLACEHOLDER, + fill_empty_source_column_values_with_placeholder, + get_canonical_vector_database_doc_id_merge_key, + create_filter_condition, ) +from dlt.destinations.job_impl import ReferenceFollowupJobRequest +from dlt.destinations.type_mapping import TypeMapperImpl if TYPE_CHECKING: NDArray = ndarray[Any, Any] else: NDArray = ndarray -EMPTY_STRING_PLACEHOLDER = "0uEoDNBpQUBwsxKbmxxB" +TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} +UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k
for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} +BATCH_PROCESS_CHUNK_SIZE = 10_000 + + +class LanceDBTypeMapper(TypeMapperImpl): + sct_to_unbound_dbt = { + "text": pa.string(), + "double": pa.float64(), + "bool": pa.bool_(), + "bigint": pa.int64(), + "binary": pa.binary(), + "date": pa.date32(), + "json": pa.string(), + } + + sct_to_dbt = {} + + dbt_to_sct = { + pa.string(): "text", + pa.float64(): "double", + pa.bool_(): "bool", + pa.int64(): "bigint", + pa.binary(): "binary", + pa.date32(): "date", + } + + def to_db_decimal_type(self, column: TColumnSchema) -> pa.Decimal128Type: + precision, scale = self.decimal_precision(column.get("precision"), column.get("scale")) + return pa.decimal128(precision, scale) + + def to_db_datetime_type( + self, + column: TColumnSchema, + table: TTableSchema = None, + ) -> pa.TimestampType: + column_name = column.get("name") + timezone = column.get("timezone") + precision = column.get("precision") + if timezone is not None or precision is not None: + logger.warning( + "LanceDB does not currently support column flags for timezone or precision." + f" These flags were used in column '{column_name}'." + ) + unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] + return pa.timestamp(unit, "UTC") + + def to_db_time_type(self, column: TColumnSchema, table: TTableSchema = None) -> pa.Time64Type: + unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] + return pa.time64(unit) + + def from_db_type( + self, + db_type: pa.DataType, + precision: Optional[int] = None, + scale: Optional[int] = None, + ) -> TColumnType: + if isinstance(db_type, pa.TimestampType): + return dict( + data_type="timestamp", + precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], + scale=scale, + ) + if isinstance(db_type, pa.Time64Type): + return dict( + data_type="time", + precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], + scale=scale, + ) + if isinstance(db_type, pa.Decimal128Type): + precision, scale = db_type.precision, db_type.scale + if (precision, scale) == self.capabilities.wei_precision: + return cast(TColumnType, dict(data_type="wei")) + return dict(data_type="decimal", precision=precision, scale=scale) + return super().from_db_type(cast(str, db_type), precision, scale) # type: ignore -def upload_batch( - records: List[DictStrAny], +def write_records( + records: DATA, /, *, db_client: DBConnection, table_name: str, - write_disposition: TWriteDisposition, - id_field_name: Optional[str] = None, + write_disposition: Optional[TWriteDisposition] = "append", + merge_key: Optional[str] = None, + remove_orphans: Optional[bool] = False, + filter_condition: Optional[str] = None, ) -> None: """Inserts records into a LanceDB table with automatic embedding computation. @@ -93,8 +183,11 @@ def upload_batch( records: The data to be inserted as payload. db_client: The LanceDB client connection. table_name: The name of the table to insert into. - id_field_name: The name of the ID field for update/merge operations. + merge_key: Keys for update/merge operations. write_disposition: The write disposition - one of 'skip', 'append', 'replace', 'merge'. + remove_orphans (bool): Whether to remove orphans after insertion or not (only merge disposition). + filter_condition (str): If None, then all such rows will be deleted. + Otherwise, the condition will be used as an SQL filter to limit what rows are deleted. 
Raises: ValueError: If the write disposition is unsupported, or `id_field_name` is not @@ -110,16 +203,17 @@ def upload_batch( ) from e try: - if write_disposition in ("append", "skip"): + if write_disposition in ("append", "skip", "replace"): tbl.add(records) - elif write_disposition == "replace": - tbl.add(records, mode="overwrite") elif write_disposition == "merge": - if not id_field_name: - raise ValueError("To perform a merge update, 'id_field_name' must be specified.") - tbl.merge_insert( - id_field_name - ).when_matched_update_all().when_not_matched_insert_all().execute(records) + if remove_orphans: + tbl.merge_insert(merge_key).when_not_matched_by_source_delete( + filter_condition + ).execute(records) + else: + tbl.merge_insert( + merge_key + ).when_matched_update_all().when_not_matched_insert_all().execute(records) else: raise DestinationTerminalException( f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" @@ -135,6 +229,8 @@ class LanceDBClient(JobClientBase, WithStateSync): """LanceDB destination handler.""" model_func: TextEmbeddingFunction + """The embedder callback used for each chunk.""" + dataset_name: str def __init__( self, @@ -152,6 +248,7 @@ def __init__( self.registry = EmbeddingFunctionRegistry.get_instance() self.type_mapper = self.capabilities.get_type_mapper() self.sentinel_table_name = config.sentinel_table_name + self.dataset_name = self.config.normalize_dataset_name(self.schema) embedding_model_provider = self.config.embedding_model_provider @@ -169,11 +266,6 @@ def __init__( ) self.vector_field_name = self.config.vector_field_name - self.id_field_name = self.config.id_field_name - - @property - def dataset_name(self) -> str: - return self.config.normalize_dataset_name(self.schema) @property def sentinel_table(self) -> str: @@ -187,7 +279,7 @@ def make_qualified_table_name(self, table_name: str) -> str: ) def get_table_schema(self, table_name: str) -> TArrowSchema: - schema_table: Table = self.db_client.open_table(table_name) + schema_table: "lancedb.table.Table" = self.db_client.open_table(table_name) schema_table.checkout_latest() schema = schema_table.schema return cast( @@ -196,13 +288,15 @@ def get_table_schema(self, table_name: str) -> TArrowSchema: ) @lancedb_error - def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table: + def create_table( + self, table_name: str, schema: TArrowSchema, mode: str = "create" + ) -> "lancedb.table.Table": """Create a LanceDB Table from the provided LanceModel or PyArrow schema. Args: schema: The table schema to create. table_name: The name of the table to create. - mode (): The mode to use when creating the table. Can be either "create" or "overwrite". + mode (str): The mode to use when creating the table. Can be either "create" or "overwrite". By default, if the table already exists, an exception is raised. If you want to overwrite the table, use mode="overwrite". """ @@ -230,7 +324,7 @@ def query_table( Returns: A LanceDB query builder. """ - query_table: Table = self.db_client.open_table(table_name) + query_table: "lancedb.table.Table" = self.db_client.open_table(table_name) query_table.checkout_latest() return query_table.search(query=query) @@ -255,7 +349,7 @@ def drop_storage(self) -> None: Deletes all tables in the dataset and all data, as well as sentinel table associated with them. - If the dataset name was not provided, it deletes all the tables in the current schema. 
+ If the dataset name wasn't provided, it deletes all the tables in the current schema. """ for table_name in self._get_table_names(): self.db_client.drop_table(table_name) @@ -282,7 +376,22 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: def is_storage_initialized(self) -> bool: return self.table_exists(self.sentinel_table) - def _create_sentinel_table(self) -> Table: + def verify_schema( + self, only_tables: Iterable[str] = None, new_jobs: Iterable[ParsedLoadJobFileName] = None + ) -> List[PreparedTableSchema]: + loaded_tables = super().verify_schema(only_tables, new_jobs) + # verify merge keys early + for load_table in loaded_tables: + if not is_nested_table(load_table) and not load_table.get(NO_REMOVE_ORPHANS_HINT): + if merge_key := get_columns_names_with_prop(load_table, "merge_key"): + if len(merge_key) > 1: + raise DestinationTerminalException( + "You cannot specify multiple merge keys with LanceDB orphan remove" + f" enabled: {merge_key}" + ) + return loaded_tables + + def _create_sentinel_table(self) -> "lancedb.table.Table": """Create an empty table to indicate that the storage is initialized.""" return self.create_table(schema=NULL_SCHEMA, table_name=self.sentinel_table) @@ -310,7 +419,7 @@ def update_stored_schema( # TODO: return a real updated table schema (like in SQL job client) self._execute_schema_update(only_tables) else: - logger.info( + logger.debug( f"Schema with hash {self.schema.stored_version_hash} " f"inserted at {schema_info.inserted_at} found " "in storage, no upgrade required" @@ -325,7 +434,7 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] try: fq_table_name = self.make_qualified_table_name(table_name) - table: Table = self.db_client.open_table(fq_table_name) + table: "lancedb.table.Table" = self.db_client.open_table(fq_table_name) table.checkout_latest() arrow_schema: TArrowSchema = table.schema except FileNotFoundError: @@ -341,34 +450,33 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] return True, table_schema @lancedb_error - def add_table_fields( - self, table_name: str, field_schemas: List[TArrowField] - ) -> Optional[Table]: - """Add multiple fields to the LanceDB table at once. + def extend_lancedb_table_schema(self, table_name: str, field_schemas: List[pa.Field]) -> None: + """Extend LanceDB table schema with empty columns. Args: - table_name: The name of the table to create the fields on. - field_schemas: The list of fields to create. + table_name: The name of the table to create the fields on. + field_schemas: The list of PyArrow Fields to create in the target LanceDB table. """ - table: Table = self.db_client.open_table(table_name) + table: "lancedb.table.Table" = self.db_client.open_table(table_name) table.checkout_latest() - arrow_table = table.to_arrow() - - # Check if any of the new fields already exist in the table. - existing_fields = set(arrow_table.schema.names) - new_fields = [field for field in field_schemas if field.name not in existing_fields] - if not new_fields: - # All fields already present, skip. - return None + try: + # Use DataFusion SQL syntax to alter fields without loading data into client memory. + # Now, the most efficient way to modify column values is in LanceDB. 
+ new_fields = { + field.name: f"CAST(NULL AS {arrow_datatype_to_fusion_datatype(field.type)})" + for field in field_schemas + } + table.add_columns(new_fields) - null_arrays = [pa.nulls(len(arrow_table), type=field.type) for field in new_fields] + # Make new columns nullable in the Arrow schema. + # Necessary because the Datafusion SQL API doesn't set new columns as nullable by default. + for field in field_schemas: + table.alter_columns({"path": field.name, "nullable": field.nullable}) - for field, null_array in zip(new_fields, null_arrays): - arrow_table = arrow_table.append_column(field, null_array) + # TODO: Update method below doesn't work for bulk NULL assignments, raise with LanceDB developers. + # table.update(values={field.name: None}) - try: - return self.db_client.create_table(table_name, arrow_table, mode="overwrite") except OSError: # Error occurred while creating the table, skip. return None @@ -376,36 +484,31 @@ def add_table_fields( def _execute_schema_update(self, only_tables: Iterable[str]) -> None: for table_name in only_tables or self.schema.tables: exists, existing_columns = self.get_storage_table(table_name) - new_columns = self.schema.get_new_table_columns( + new_columns: List[TColumnSchema] = self.schema.get_new_table_columns( table_name, existing_columns, self.capabilities.generates_case_sensitive_identifiers(), ) - embedding_fields: List[str] = get_columns_names_with_prop( - self.schema.get_table(table_name), VECTORIZE_HINT - ) logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") - if len(new_columns) > 0: + if new_columns: if exists: field_schemas: List[TArrowField] = [ make_arrow_field_schema(column["name"], column, self.type_mapper) for column in new_columns ] fq_table_name = self.make_qualified_table_name(table_name) - self.add_table_fields(fq_table_name, field_schemas) + self.extend_lancedb_table_schema(fq_table_name, field_schemas) else: if table_name not in self.schema.dlt_table_names(): embedding_fields = get_columns_names_with_prop( self.schema.get_table(table_name=table_name), VECTORIZE_HINT ) vector_field_name = self.vector_field_name - id_field_name = self.id_field_name embedding_model_func = self.model_func embedding_model_dimensions = self.config.embedding_model_dimensions else: embedding_fields = None vector_field_name = None - id_field_name = None embedding_model_func = None embedding_model_dimensions = None @@ -417,7 +520,6 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: embedding_model_func=embedding_model_func, embedding_model_dimensions=embedding_model_dimensions, vector_field_name=vector_field_name, - id_field_name=id_field_name, ) fq_table_name = self.make_qualified_table_name(table_name) self.create_table(fq_table_name, table_schema) @@ -446,7 +548,8 @@ def update_schema_in_storage(self) -> None: write_disposition = self.schema.get_table(self.schema.version_table_name).get( "write_disposition" ) - upload_batch( + + write_records( records, db_client=self.db_client, table_name=fq_version_table_name, @@ -459,15 +562,17 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name) fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) - state_table_: Table = self.db_client.open_table(fq_state_table_name) + state_table_: "lancedb.table.Table" = self.db_client.open_table(fq_state_table_name) state_table_.checkout_latest() - loads_table_: Table = 
self.db_client.open_table(fq_loads_table_name) + loads_table_: "lancedb.table.Table" = self.db_client.open_table(fq_loads_table_name) loads_table_.checkout_latest() # normalize property names p_load_id = self.schema.naming.normalize_identifier("load_id") - p_dlt_load_id = self.schema.naming.normalize_identifier(C_DLT_LOAD_ID) + p_dlt_load_id = self.schema.naming.normalize_identifier( + self.schema.data_item_normalizer.c_dlt_load_id # type: ignore[attr-defined] + ) p_pipeline_name = self.schema.naming.normalize_identifier("pipeline_name") p_status = self.schema.naming.normalize_identifier("status") p_version = self.schema.naming.normalize_identifier("version") @@ -476,7 +581,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: p_created_at = self.schema.naming.normalize_identifier("created_at") p_version_hash = self.schema.naming.normalize_identifier("version_hash") - # Read the tables into memory as Arrow tables, with pushdown predicates, so we pull as less + # Read the tables into memory as Arrow tables, with pushdown predicates, so we pull as little # data into memory as possible. state_table = ( state_table_.search() @@ -508,7 +613,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) - version_table: Table = self.db_client.open_table(fq_version_table_name) + version_table: "lancedb.table.Table" = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() p_version_hash = self.schema.naming.normalize_identifier("version_hash") p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") @@ -524,8 +629,6 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI ) ).to_list() - # LanceDB's ORDER BY clause doesn't seem to work. - # See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341 most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], @@ -543,7 +646,7 @@ def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaIn """Retrieves newest schema from destination storage.""" fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) - version_table: Table = self.db_client.open_table(fq_version_table_name) + version_table: "lancedb.table.Table" = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() p_version_hash = self.schema.naming.normalize_identifier("version_hash") p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") @@ -558,8 +661,6 @@ def get_stored_schema(self, schema_name: str = None) -> Optional[StorageSchemaIn query = query.where(f'`{p_schema_name}` = "{schema_name}"', prefilter=True) schemas = query.to_list() - # LanceDB's ORDER BY clause doesn't seem to work. 
- # See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341 most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], @@ -591,16 +692,14 @@ def complete_load(self, load_id: str) -> None: self.schema.naming.normalize_identifier("schema_name"): self.schema.name, self.schema.naming.normalize_identifier("status"): 0, self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), - self.schema.naming.normalize_identifier( - "schema_version_hash" - ): None, # Payload schema must match the target schema. + self.schema.naming.normalize_identifier("schema_version_hash"): None, } ] fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) write_disposition = self.schema.get_table(self.schema.loads_table_name).get( "write_disposition" ) - upload_batch( + write_records( records, db_client=self.db_client, table_name=fq_loads_table_name, @@ -610,80 +709,152 @@ def complete_load(self, load_id: str) -> None: def create_load_job( self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: - return LanceDBLoadJob( - file_path=file_path, - type_mapper=self.type_mapper, - model_func=self.model_func, - fq_table_name=self.make_qualified_table_name(table["name"]), + if ReferenceFollowupJobRequest.is_reference_job(file_path): + return LanceDBRemoveOrphansJob(file_path) + else: + return LanceDBLoadJob(file_path, table) + + def create_table_chain_completed_followup_jobs( + self, + table_chain: Sequence[TTableSchema], + completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, + ) -> List[FollowupJobRequest]: + jobs = super().create_table_chain_completed_followup_jobs( + table_chain, completed_table_chain_jobs # type: ignore[arg-type] ) + # Orphan removal is only supported for upsert strategy because we need a deterministic key hash. 
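+        # Descriptive note (added comment): the reference job created below points at every file loaded for this chain; LanceDBRemoveOrphansJob later re-reads those files and deletes stored child/chunk rows that belong to the loaded documents but no longer appear in the new payload.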
+ first_table_in_chain = table_chain[0] + if first_table_in_chain.get( + "write_disposition" + ) == "merge" and not first_table_in_chain.get(NO_REMOVE_ORPHANS_HINT): + all_job_paths_ordered = [ + job.file_path + for table in table_chain + for job in completed_table_chain_jobs + if job.job_file_info.table_name == table.get("name") + ] + root_table_file_name = FileStorage.get_file_name_from_file_path( + all_job_paths_ordered[0] + ) + jobs.append(ReferenceFollowupJobRequest(root_table_file_name, all_job_paths_ordered)) + return jobs def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() -class LanceDBLoadJob(RunnableLoadJob): +class LanceDBLoadJob(RunnableLoadJob, HasFollowupJobs): arrow_schema: TArrowSchema def __init__( self, file_path: str, - type_mapper: DataTypeMapper, - model_func: TextEmbeddingFunction, - fq_table_name: str, + table_schema: TTableSchema, ) -> None: super().__init__(file_path) - self._type_mapper = type_mapper - self._fq_table_name: str = fq_table_name - self._model_func = model_func self._job_client: "LanceDBClient" = None + self._table_schema: TTableSchema = table_schema def run(self) -> None: - self._db_client: DBConnection = self._job_client.db_client - self._embedding_model_func: TextEmbeddingFunction = self._model_func - self._embedding_model_dimensions: int = self._job_client.config.embedding_model_dimensions - self._id_field_name: str = self._job_client.config.id_field_name - - unique_identifiers: Sequence[str] = list_merge_identifiers(self._load_table) + db_client: DBConnection = self._job_client.db_client + fq_table_name: str = self._job_client.make_qualified_table_name(self._table_schema["name"]) write_disposition: TWriteDisposition = cast( TWriteDisposition, self._load_table.get("write_disposition", "append") ) - with FileStorage.open_zipsafe_ro(self._file_path) as f: - records: List[DictStrAny] = [json.loads(line) for line in f] + with FileStorage.open_zipsafe_ro(self._file_path, mode="rb") as f: + arrow_table: pa.Table = pq.read_table(f) # Replace empty strings with placeholder string if OpenAI is used. # https://github.com/lancedb/lancedb/issues/1577#issuecomment-2318104218. if (self._job_client.config.embedding_model_provider == "openai") and ( source_columns := get_columns_names_with_prop(self._load_table, VECTORIZE_HINT) ): - records = [ - { - k: EMPTY_STRING_PLACEHOLDER if k in source_columns and v in ("", None) else v - for k, v in record.items() - } - for record in records - ] - - if self._load_table not in self._schema.dlt_tables(): - for record in records: - # Add reserved ID fields. - uuid_id = ( - generate_uuid(record, unique_identifiers, self._fq_table_name) - if unique_identifiers - else str(uuid.uuid4()) - ) - record.update({self._id_field_name: uuid_id}) + arrow_table = fill_empty_source_column_values_with_placeholder( + arrow_table, source_columns, EMPTY_STRING_PLACEHOLDER + ) - # LanceDB expects all fields in the target arrow table to be present in the data payload. - # We add and set these missing fields, that are fields not present in the target schema, to NULL. - missing_fields = set(self._load_table["columns"]) - set(record) - for field in missing_fields: - record[field] = None + # We need upsert merge's deterministic _dlt_id to perform orphan removal. + # Hence, we require at least a primary key on the root table if the merge disposition is chosen. + if ( + (self._load_table not in self._schema.dlt_table_names()) + and not is_nested_table(self._load_table) # Is root table. 
+ and (write_disposition == "merge") + and (not get_columns_names_with_prop(self._load_table, "primary_key")) + ): + raise DestinationTerminalException( + "LanceDB's write disposition requires at least one explicit primary key." + ) - upload_batch( - records, - db_client=self._db_client, - table_name=self._fq_table_name, + dlt_id = self._schema.naming.normalize_identifier( + self._schema.data_item_normalizer.c_dlt_id # type: ignore[attr-defined] + ) + write_records( + arrow_table, + db_client=db_client, + table_name=fq_table_name, write_disposition=write_disposition, - id_field_name=self._id_field_name, + merge_key=dlt_id, ) + + +class LanceDBRemoveOrphansJob(RunnableLoadJob): + orphaned_ids: Set[str] + + def __init__( + self, + file_path: str, + ) -> None: + super().__init__(file_path) + self._job_client: "LanceDBClient" = None + self.references = ReferenceFollowupJobRequest.resolve_references(file_path) + + def run(self) -> None: + dlt_load_id = self._schema.data_item_normalizer.c_dlt_load_id # type: ignore[attr-defined] + dlt_id = self._schema.data_item_normalizer.c_dlt_id # type: ignore[attr-defined] + dlt_root_id = self._schema.data_item_normalizer.c_dlt_root_id # type: ignore[attr-defined] + + db_client: DBConnection = self._job_client.db_client + table_lineage: TTableLineage = [ + TableJob( + table_schema=self._schema.get_table( + ParsedLoadJobFileName.parse(file_path_).table_name + ), + table_name=ParsedLoadJobFileName.parse(file_path_).table_name, + file_path=file_path_, + ) + for file_path_ in self.references + ] + + for job in table_lineage: + target_is_root_table = not is_nested_table(job.table_schema) + fq_table_name = self._job_client.make_qualified_table_name(job.table_name) + file_path = job.file_path + with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: + payload_arrow_table: pa.Table = pq.read_table(f) + + if target_is_root_table: + canonical_doc_id_field = get_canonical_vector_database_doc_id_merge_key( + job.table_schema + ) + filter_condition = create_filter_condition( + canonical_doc_id_field, payload_arrow_table[canonical_doc_id_field] + ) + merge_key = dlt_load_id + + else: + filter_condition = create_filter_condition( + dlt_root_id, + payload_arrow_table[dlt_root_id], + ) + merge_key = dlt_id + + write_records( + payload_arrow_table, + db_client=db_client, + table_name=fq_table_name, + write_disposition="merge", + merge_key=merge_key, + remove_orphans=True, + filter_condition=filter_condition, + ) diff --git a/dlt/destinations/impl/lancedb/schema.py b/dlt/destinations/impl/lancedb/schema.py index 27c6fb33a1..25dfbc840a 100644 --- a/dlt/destinations/impl/lancedb/schema.py +++ b/dlt/destinations/impl/lancedb/schema.py @@ -1,6 +1,5 @@ """Utilities for creating arrow schemas from table schemas.""" - -from dlt.common.json import json +from collections import namedtuple from typing import ( List, cast, @@ -11,17 +10,19 @@ from lancedb.embeddings import TextEmbeddingFunction # type: ignore from typing_extensions import TypeAlias +from dlt.common.destination.capabilities import DataTypeMapper +from dlt.common.json import json from dlt.common.schema import Schema, TColumnSchema from dlt.common.typing import DictStrAny -from dlt.common.destination.capabilities import DataTypeMapper - TArrowSchema: TypeAlias = pa.Schema TArrowDataType: TypeAlias = pa.DataType TArrowField: TypeAlias = pa.Field NULL_SCHEMA: TArrowSchema = pa.schema([]) """Empty pyarrow Schema with no fields.""" +TableJob = namedtuple("TableJob", ["table_schema", "table_name", "file_path"]) 
+TTableLineage: TypeAlias = List[TableJob] def arrow_schema_to_dict(schema: TArrowSchema) -> DictStrAny: @@ -42,7 +43,6 @@ def make_arrow_table_schema( table_name: str, schema: Schema, type_mapper: DataTypeMapper, - id_field_name: Optional[str] = None, vector_field_name: Optional[str] = None, embedding_fields: Optional[List[str]] = None, embedding_model_func: Optional[TextEmbeddingFunction] = None, @@ -51,9 +51,6 @@ def make_arrow_table_schema( """Creates a PyArrow schema from a dlt schema.""" arrow_schema: List[TArrowField] = [] - if id_field_name: - arrow_schema.append(pa.field(id_field_name, pa.string())) - if embedding_fields: # User's provided dimension config, if provided, takes precedence. vec_size = embedding_model_dimensions or embedding_model_func.ndims() @@ -83,3 +80,22 @@ def make_arrow_table_schema( metadata["embedding_functions"] = json.dumps(embedding_functions).encode("utf-8") return pa.schema(arrow_schema, metadata=metadata) + + +def arrow_datatype_to_fusion_datatype(arrow_type: TArrowSchema) -> str: + type_map = { + pa.bool_(): "BOOLEAN", + pa.int64(): "BIGINT", + pa.float64(): "DOUBLE", + pa.utf8(): "STRING", + pa.binary(): "BYTEA", + pa.date32(): "DATE", + } + + if isinstance(arrow_type, pa.Decimal128Type): + return f"DECIMAL({arrow_type.precision}, {arrow_type.scale})" + + if isinstance(arrow_type, pa.TimestampType): + return "TIMESTAMP" + + return type_map.get(arrow_type, "UNKNOWN") diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index aeacd4d34b..56991b090f 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -1,13 +1,16 @@ import os -import uuid -from typing import Sequence, Union, Dict +from typing import Union, Dict, List +import pyarrow as pa + +from dlt.common import logger +from dlt.common.data_writers.escape import escape_lancedb_literal +from dlt.common.destination.exceptions import DestinationTerminalException from dlt.common.schema import TTableSchema -from dlt.common.schema.utils import get_columns_names_with_prop -from dlt.common.typing import DictStrAny +from dlt.common.schema.utils import get_columns_names_with_prop, get_first_column_name_with_prop from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider - +EMPTY_STRING_PLACEHOLDER = "0uEoDNBpQUBwsxKbmxxB" PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = { "cohere": "COHERE_API_KEY", "gemini-text": "GOOGLE_API_KEY", @@ -16,40 +19,55 @@ } -def generate_uuid(data: DictStrAny, unique_identifiers: Sequence[str], table_name: str) -> str: - """Generates deterministic UUID - used for deduplication. +def set_non_standard_providers_environment_variables( + embedding_model_provider: TEmbeddingProvider, api_key: Union[str, None] +) -> None: + if embedding_model_provider in PROVIDER_ENVIRONMENT_VARIABLES_MAP: + os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or "" - Args: - data (Dict[str, Any]): Arbitrary data to generate UUID for. - unique_identifiers (Sequence[str]): A list of unique identifiers. - table_name (str): LanceDB table name. - Returns: - str: A string representation of the generated UUID. 
- """ - data_id = "_".join(str(data[key]) for key in unique_identifiers) - return str(uuid.uuid5(uuid.NAMESPACE_DNS, table_name + data_id)) +def get_canonical_vector_database_doc_id_merge_key( + load_table: TTableSchema, +) -> str: + if merge_key := get_first_column_name_with_prop(load_table, "merge_key"): + return merge_key + elif primary_key := get_columns_names_with_prop(load_table, "primary_key"): + # No merge key defined, warn and assume the first element of the primary key is `doc_id`. + logger.warning( + "Merge strategy selected without defined merge key - using the first element of the" + f" primary key ({primary_key}) as merge key." + ) + return primary_key[0] + else: + raise DestinationTerminalException( + "You must specify at least a primary key in order to perform orphan removal." + ) -def list_merge_identifiers(table_schema: TTableSchema) -> Sequence[str]: - """Returns a list of merge keys for a table used for either merging or deduplication. +def fill_empty_source_column_values_with_placeholder( + table: pa.Table, source_columns: List[str], placeholder: str +) -> pa.Table: + """ + Replaces empty strings and null values in the specified source columns of an Arrow table with a placeholder string. Args: - table_schema (TTableSchema): a dlt table schema. + table (pa.Table): The input Arrow table. + source_columns (List[str]): A list of column names to replace empty strings and null values in. + placeholder (str): The placeholder string to use for replacement. Returns: - Sequence[str]: A list of unique column identifiers. + pa.Table: The modified Arrow table with empty strings and null values replaced in the specified columns. """ - if table_schema.get("write_disposition") == "merge": - primary_keys = get_columns_names_with_prop(table_schema, "primary_key") - merge_keys = get_columns_names_with_prop(table_schema, "merge_key") - if join_keys := list(set(primary_keys + merge_keys)): - return join_keys - return get_columns_names_with_prop(table_schema, "unique") + for col_name in source_columns: + column = table[col_name] + filled_column = pa.compute.fill_null(column, fill_value=placeholder) + new_column = pa.compute.replace_substring_regex( + filled_column, pattern=r"^$", replacement=placeholder + ) + table = table.set_column(table.column_names.index(col_name), col_name, new_column) + return table -def set_non_standard_providers_environment_variables( - embedding_model_provider: TEmbeddingProvider, api_key: Union[str, None] -) -> None: - if embedding_model_provider in PROVIDER_ENVIRONMENT_VARIABLES_MAP: - os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or "" +def create_filter_condition(field_name: str, array: pa.Array) -> str: + array_py = array.to_pylist() + return f"{field_name} IN ({', '.join(map(escape_lancedb_literal, array_py))})" diff --git a/dlt/destinations/impl/qdrant/qdrant_adapter.py b/dlt/destinations/impl/qdrant/qdrant_adapter.py index abe301fff0..bbc2d719a8 100644 --- a/dlt/destinations/impl/qdrant/qdrant_adapter.py +++ b/dlt/destinations/impl/qdrant/qdrant_adapter.py @@ -34,7 +34,7 @@ def qdrant_adapter( """ resource = get_resource_for_adapter(data) - column_hints: TTableSchemaColumns = {} + column_hints: TTableSchemaColumns = None if embed: if isinstance(embed, str): @@ -44,6 +44,7 @@ def qdrant_adapter( "embed must be a list of column names or a single column name as a string" ) + column_hints = {} for column_name in embed: column_hints[column_name] = { "name": column_name, diff --git 
a/dlt/destinations/impl/weaviate/weaviate_adapter.py b/dlt/destinations/impl/weaviate/weaviate_adapter.py index 9bd0b41783..0ca9047528 100644 --- a/dlt/destinations/impl/weaviate/weaviate_adapter.py +++ b/dlt/destinations/impl/weaviate/weaviate_adapter.py @@ -87,6 +87,7 @@ def weaviate_adapter( TOKENIZATION_HINT: method, # type: ignore } + # this makes sure that {} as column_hints never gets into apply_hints (that would reset existing columns) if not column_hints: raise ValueError("Either 'vectorize' or 'tokenization' must be specified.") else: diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index ae27213a7c..f59f087f4f 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -78,7 +78,6 @@ def from_table_chain( job = cls(file_info.file_name()) job._save_text_file("\n".join(sql)) except Exception as e: - # raise exception with some context raise SqlJobCreationException(e, table_chain) from e return job diff --git a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md index a85d6ddc7e..b2aec665ab 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md +++ b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md @@ -179,19 +179,61 @@ info = pipeline.run( ### Merge -The [merge](../../general-usage/incremental-loading.md) write disposition merges the data from the resource with the data at the destination based on a unique identifier. +The [merge](../../general-usage/incremental-loading.md) write disposition merges the data from the resource with the data at the destination based on a unique identifier. The LanceDB destination merge write disposition only supports upsert strategy. This updates existing records and inserts new ones based on a unique identifier. + +You can specify the merge disposition, primary key, and merge key either in a resource or adapter: + +```py +@dlt.resource( + primary_key=["doc_id", "chunk_id"], + merge_key=["doc_id"], + write_disposition={"disposition": "merge", "strategy": "upsert"}, +) +def my_rag_docs( + data: List[DictStrAny], +) -> Generator[List[DictStrAny], None, None]: + yield data +``` + +Or: + +```py +pipeline.run( + lancedb_adapter( + my_new_rag_docs, + merge_key="doc_id" + ), + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key=["doc_id", "chunk_id"], +) +``` + +The `primary_key` uniquely identifies each record, typically comprising a document ID and a chunk ID. +The `merge_key`, which cannot be compound, should correspond to the canonical `doc_id` used in vector databases and represent the document identifier in your data model. +It must be the first element of the `primary_key`. +This `merge_key` is crucial for document identification and orphan removal during merge operations. +This structure ensures proper record identification and maintains consistency with vector database concepts. + + +#### Orphan Removal + +LanceDB **automatically removes orphaned chunks** when updating or deleting parent documents during a merge operation. To disable this feature: ```py pipeline.run( lancedb_adapter( movies, embed="title", + no_remove_orphans=True # Disable with the `no_remove_orphans` flag. 
), - write_disposition="merge", - primary_key="id", + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key=["doc_id", "chunk_id"], ) ``` +Note: While it's possible to omit the `merge_key` for brevity (in which case it is assumed to be the first entry of `primary_key`), +explicitly specifying both is recommended for clarity. + ### Append This is the default disposition. It will append the data to the existing data in the destination. @@ -200,7 +242,6 @@ This is the default disposition. It will append the data to the existing data in - `dataset_separator`: The character used to separate the dataset name from table names. Defaults to "___". - `vector_field_name`: The name of the special field to store vector embeddings. Defaults to "vector". -- `id_field_name`: The name of the special field used for deduplication and merging. Defaults to "id__". - `max_retries`: The maximum number of retries for embedding operations. Set to 0 to disable retries. Defaults to 3. ## dbt support diff --git a/poetry.lock b/poetry.lock index 5d81d06969..f4667da374 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2167,32 +2167,33 @@ typing-extensions = ">=3.10.0" [[package]] name = "databricks-sql-connector" -version = "3.3.0" +version = "2.9.6" description = "Databricks SQL Connector for Python" optional = true -python-versions = "<4.0.0,>=3.8.0" +python-versions = "<4.0.0,>=3.7.1" files = [ - {file = "databricks_sql_connector-3.3.0-py3-none-any.whl", hash = "sha256:55ee5a4a11291bf91a235ac76e41b419ddd66a9a321065a8bfaf119acbb26d6b"}, - {file = "databricks_sql_connector-3.3.0.tar.gz", hash = "sha256:19e82965da4c86574adfe9f788c17b4494d98eb8075ba4fd4306573d2edbf194"}, + {file = "databricks_sql_connector-2.9.6-py3-none-any.whl", hash = "sha256:d830abf86e71d2eb83c6a7b7264d6c03926a8a83cec58541ddd6b83d693bde8f"}, + {file = "databricks_sql_connector-2.9.6.tar.gz", hash = "sha256:e55f5b8ede8ae6c6f31416a4cf6352f0ac019bf6875896c668c7574ceaf6e813"}, ] [package.dependencies] +alembic = ">=1.0.11,<2.0.0" lz4 = ">=4.0.2,<5.0.0" numpy = [ - {version = ">=1.16.6,<2.0.0", markers = "python_version >= \"3.8\" and python_version < \"3.11\""}, - {version = ">=1.23.4,<2.0.0", markers = "python_version >= \"3.11\""}, + {version = ">=1.16.6", markers = "python_version >= \"3.7\" and python_version < \"3.11\""}, + {version = ">=1.23.4", markers = "python_version >= \"3.11\""}, ] oauthlib = ">=3.1.0,<4.0.0" openpyxl = ">=3.0.10,<4.0.0" -pandas = {version = ">=1.2.5,<2.2.0", markers = "python_version >= \"3.8\""} -pyarrow = ">=14.0.1,<17" +pandas = {version = ">=1.2.5,<3.0.0", markers = "python_version >= \"3.8\""} +pyarrow = [ + {version = ">=6.0.0", markers = "python_version >= \"3.7\" and python_version < \"3.11\""}, + {version = ">=10.0.1", markers = "python_version >= \"3.11\""}, +] requests = ">=2.18.1,<3.0.0" -thrift = ">=0.16.0,<0.21.0" -urllib3 = ">=1.26" - -[package.extras] -alembic = ["alembic (>=1.0.11,<2.0.0)", "sqlalchemy (>=2.0.21)"] -sqlalchemy = ["sqlalchemy (>=2.0.21)"] +sqlalchemy = ">=1.3.24,<2.0.0" +thrift = ">=0.16.0,<0.17.0" +urllib3 = ">=1.0" [[package]] name = "db-dtypes" @@ -6760,52 +6761,55 @@ files = [ [[package]] name = "pyarrow" -version = "16.1.0" +version = "17.0.0" description = "Python library for Apache Arrow" optional = false python-versions = ">=3.8" files = [ - {file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"}, - {file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"}, - {file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"}, - {file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"}, - {file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"}, - {file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"}, - {file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"}, - {file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"}, - {file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"}, - {file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"}, - {file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"}, - {file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"}, - {file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"}, - {file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"}, - {file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"}, - {file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"}, + {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"}, + {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da1e060b3876faa11cee287839f9cc7cdc00649f475714b8680a05fd9071d545"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75c06d4624c0ad6674364bb46ef38c3132768139ddec1c56582dbac54f2663e2"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:fa3c246cc58cb5a4a5cb407a18f193354ea47dd0648194e6265bd24177982fe8"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:f7ae2de664e0b158d1607699a16a488de3d008ba99b3a7aa5de1cbc13574d047"}, + {file = "pyarrow-17.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5984f416552eea15fd9cee03da53542bf4cddaef5afecefb9aa8d1010c335087"}, + {file = "pyarrow-17.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:1c8856e2ef09eb87ecf937104aacfa0708f22dfeb039c363ec99735190ffb977"}, + {file = "pyarrow-17.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e19f569567efcbbd42084e87f948778eb371d308e137a0f97afe19bb860ccb3"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b244dc8e08a23b3e352899a006a26ae7b4d0da7bb636872fa8f5884e70acf15"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b72e87fe3e1db343995562f7fff8aee354b55ee83d13afba65400c178ab2597"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dc5c31c37409dfbc5d014047817cb4ccd8c1ea25d19576acf1a001fe07f5b420"}, + {file = 
"pyarrow-17.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e3343cb1e88bc2ea605986d4b94948716edc7a8d14afd4e2c097232f729758b4"}, + {file = "pyarrow-17.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:a27532c38f3de9eb3e90ecab63dfda948a8ca859a66e3a47f5f42d1e403c4d03"}, + {file = "pyarrow-17.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9b8a823cea605221e61f34859dcc03207e52e409ccf6354634143e23af7c8d22"}, + {file = "pyarrow-17.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f1e70de6cb5790a50b01d2b686d54aaf73da01266850b05e3af2a1bc89e16053"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0071ce35788c6f9077ff9ecba4858108eebe2ea5a3f7cf2cf55ebc1dbc6ee24a"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:757074882f844411fcca735e39aae74248a1531367a7c80799b4266390ae51cc"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ba11c4f16976e89146781a83833df7f82077cdab7dc6232c897789343f7891a"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b0c6ac301093b42d34410b187bba560b17c0330f64907bfa4f7f7f2444b0cf9b"}, + {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"}, + {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"}, + {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"}, + {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"}, + {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"}, + {file = "pyarrow-17.0.0.tar.gz", hash = 
"sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"}, ] [package.dependencies] numpy = ">=1.16.6" +[package.extras] +test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] + [[package]] name = "pyasn1" version = "0.5.0" @@ -9867,4 +9871,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "97666ad4613f07d95c5388bae41befe6cc10c88d02ee8f1cee27b161e13729f1" +content-hash = "6393c17b6865a78b3adb2732e4b2e416b7dc869f07649e60cc15352cab49444f" diff --git a/pyproject.toml b/pyproject.toml index 165364908c..38262a95d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -239,7 +239,7 @@ dbt-duckdb = ">=1.2.0" pymongo = ">=4.3.3" pandas = ">2" alive-progress = ">=3.0.1" -pyarrow = ">=14.0.0" +pyarrow = ">=17.0.0" psycopg2-binary = ">=2.9" lancedb = { version = ">=0.8.2", markers = "python_version >= '3.9'", allow-prereleases = true } openai = ">=1.45" diff --git a/tests/load/lancedb/test_merge.py b/tests/load/lancedb/test_merge.py new file mode 100644 index 0000000000..f04c846df7 --- /dev/null +++ b/tests/load/lancedb/test_merge.py @@ -0,0 +1,425 @@ +from typing import Iterator, List, Generator, Any + +import numpy as np +import pandas as pd +import pytest +from lancedb.table import Table # type: ignore +from pandas import DataFrame +from pandas.testing import assert_frame_equal + +import dlt +from dlt.common.typing import DictStrAny, DictStrStr +from dlt.common.utils import uniq_id +from dlt.destinations.impl.lancedb.lancedb_adapter import ( + lancedb_adapter, +) +from tests.load.lancedb.utils import chunk_document +from tests.load.utils import ( + drop_active_pipeline_data, + sequence_generator, +) +from tests.pipeline.utils import ( + assert_load_info, +) + + +# Mark all tests as essential, don't remove. +pytestmark = pytest.mark.essential + + +@pytest.fixture(autouse=True) +def drop_lancedb_data() -> Iterator[None]: + yield + drop_active_pipeline_data() + + +def test_lancedb_remove_nested_orphaned_records() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_remove_orphaned_records", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_{uniq_id()}", + dev_mode=True, + ) + + @dlt.resource( + table_name="parent", + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key="id", + merge_key="id", + ) + def identity_resource( + data: List[DictStrAny], + ) -> Generator[List[DictStrAny], None, None]: + yield data + + run_1 = [ + { + "id": 1, + "child": [ + {"bar": 1, "grandchild": [{"baz": 1}, {"baz": 2}]}, + {"bar": 2, "grandchild": [{"baz": 3}]}, + ], + }, + {"id": 2, "child": [{"bar": 3, "grandchild": [{"baz": 4}]}]}, + { + "id": 3, + "child": [ + {"bar": 10, "grandchild": [{"baz": 5}]}, + {"bar": 11, "grandchild": [{"baz": 6}, {"baz": 7}]}, + ], + }, + ] + info = pipeline.run(identity_resource(run_1)) + assert_load_info(info) + + run_2 = [ + { + "id": 1, + "child": [{"bar": 1, "grandchild": [{"baz": 1}]}], + }, # Removes bar_2, baz_2 and baz_3. + { + "id": 2, + "child": [{"bar": 4, "grandchild": [{"baz": 8}]}], + }, # Removes bar_3, baz_4. 
+ ] + info = pipeline.run(identity_resource(run_2)) + assert_load_info(info) + + with pipeline.destination_client() as client: + expected_parent_data = pd.DataFrame( + data=[ + {"id": 1}, + {"id": 2}, + {"id": 3}, + ] + ) + + expected_child_data = pd.DataFrame( + data=[ + {"bar": 1}, + {"bar": 4}, + {"bar": 10}, + {"bar": 11}, + ] + ) + + expected_grandchild_data = pd.DataFrame( + data=[ + {"baz": 1}, + {"baz": 8}, + {"baz": 5}, + {"baz": 6}, + {"baz": 7}, + ] + ) + + parent_table_name = client.make_qualified_table_name("parent") # type: ignore[attr-defined] + child_table_name = client.make_qualified_table_name("parent__child") # type: ignore[attr-defined] + grandchild_table_name = client.make_qualified_table_name( # type: ignore[attr-defined] + "parent__child__grandchild" + ) + + parent_tbl = client.db_client.open_table(parent_table_name) # type: ignore[attr-defined] + child_tbl = client.db_client.open_table(child_table_name) # type: ignore[attr-defined] + grandchild_tbl = client.db_client.open_table(grandchild_table_name) # type: ignore[attr-defined] + + actual_parent_df = parent_tbl.to_pandas().sort_values(by="id").reset_index(drop=True) + actual_child_df = child_tbl.to_pandas().sort_values(by="bar").reset_index(drop=True) + actual_grandchild_df = ( + grandchild_tbl.to_pandas().sort_values(by="baz").reset_index(drop=True) + ) + + expected_parent_data = expected_parent_data.sort_values(by="id").reset_index(drop=True) + expected_child_data = expected_child_data.sort_values(by="bar").reset_index(drop=True) + expected_grandchild_data = expected_grandchild_data.sort_values(by="baz").reset_index( + drop=True + ) + + assert_frame_equal(actual_parent_df[["id"]], expected_parent_data) + assert_frame_equal(actual_child_df[["bar"]], expected_child_data) + assert_frame_equal(actual_grandchild_df[["baz"]], expected_grandchild_data) + + +def test_lancedb_remove_orphaned_records_root_table() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_remove_orphaned_records_root_table", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_root_table_{uniq_id()}", + dev_mode=True, + ) + + @dlt.resource( + table_name="root", + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key=["doc_id", "chunk_hash"], + merge_key=["doc_id"], + ) + def identity_resource( + data: List[DictStrAny], + ) -> Generator[List[DictStrAny], None, None]: + yield data + + lancedb_adapter(identity_resource) + + run_1 = [ + {"doc_id": 1, "chunk_hash": "1a"}, + {"doc_id": 2, "chunk_hash": "2a"}, + {"doc_id": 2, "chunk_hash": "2b"}, + {"doc_id": 2, "chunk_hash": "2c"}, + {"doc_id": 3, "chunk_hash": "3a"}, + {"doc_id": 3, "chunk_hash": "3b"}, + ] + info = pipeline.run(identity_resource(run_1)) + assert_load_info(info) + + run_2 = [ + {"doc_id": 2, "chunk_hash": "2d"}, + {"doc_id": 2, "chunk_hash": "2e"}, + {"doc_id": 3, "chunk_hash": "3b"}, + ] + info = pipeline.run(identity_resource(run_2)) + assert_load_info(info) + + with pipeline.destination_client() as client: + expected_root_table_df = ( + pd.DataFrame( + data=[ + {"doc_id": 1, "chunk_hash": "1a"}, + {"doc_id": 2, "chunk_hash": "2d"}, + {"doc_id": 2, "chunk_hash": "2e"}, + {"doc_id": 3, "chunk_hash": "3b"}, + ] + ) + .sort_values(by=["doc_id", "chunk_hash"]) + .reset_index(drop=True) + ) + + root_table_name = client.make_qualified_table_name("root") # type: ignore[attr-defined] + tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] + + actual_root_df: DataFrame = ( + 
tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash"]).reset_index(drop=True) + )[["doc_id", "chunk_hash"]] + + assert_frame_equal(actual_root_df, expected_root_table_df) + + +def test_lancedb_remove_orphaned_records_root_table_string_doc_id() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_remove_orphaned_records_root_table", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_root_table_{uniq_id()}", + dev_mode=True, + ) + + @dlt.resource( + table_name="root", + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key=["doc_id", "chunk_hash"], + merge_key=["doc_id"], + ) + def identity_resource( + data: List[DictStrAny], + ) -> Generator[List[DictStrAny], None, None]: + yield data + + lancedb_adapter(identity_resource) + + run_1 = [ + {"doc_id": "A", "chunk_hash": "1a"}, + {"doc_id": "B", "chunk_hash": "2a"}, + {"doc_id": "B", "chunk_hash": "2b"}, + {"doc_id": "B", "chunk_hash": "2c"}, + {"doc_id": "C", "chunk_hash": "3a"}, + {"doc_id": "C", "chunk_hash": "3b"}, + ] + info = pipeline.run(identity_resource(run_1)) + assert_load_info(info) + + run_2 = [ + {"doc_id": "B", "chunk_hash": "2d"}, + {"doc_id": "B", "chunk_hash": "2e"}, + {"doc_id": "C", "chunk_hash": "3b"}, + ] + info = pipeline.run(identity_resource(run_2)) + assert_load_info(info) + + with pipeline.destination_client() as client: + expected_root_table_df = ( + pd.DataFrame( + data=[ + {"doc_id": "A", "chunk_hash": "1a"}, + {"doc_id": "B", "chunk_hash": "2d"}, + {"doc_id": "B", "chunk_hash": "2e"}, + {"doc_id": "C", "chunk_hash": "3b"}, + ] + ) + .sort_values(by=["doc_id", "chunk_hash"]) + .reset_index(drop=True) + ) + + root_table_name = client.make_qualified_table_name("root") # type: ignore[attr-defined] + tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] + + actual_root_df: DataFrame = ( + tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash"]).reset_index(drop=True) + )[["doc_id", "chunk_hash"]] + + assert_frame_equal(actual_root_df, expected_root_table_df) + + +def test_lancedb_root_table_remove_orphaned_records_with_real_embeddings() -> None: + @dlt.resource( + write_disposition={"disposition": "merge", "strategy": "upsert"}, + table_name="document", + primary_key=["doc_id", "chunk"], + merge_key="doc_id", + ) + def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: + for doc in docs: + doc_id = doc["doc_id"] + for chunk in chunk_document(doc["text"]): + yield {"doc_id": doc_id, "doc_text": doc["text"], "chunk": chunk} + + @dlt.source() + def documents_source( + docs: List[DictStrAny], + ) -> Any: + return documents(docs) + + lancedb_adapter( + documents, + embed=["chunk"], + ) + + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_remove_orphaned_records_with_embeddings", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_{uniq_id()}", + dev_mode=True, + ) + + initial_docs = [ + { + "text": ( + "This is the first document. It contains some text that will be chunked and" + " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)" + ), + "doc_id": 1, + }, + { + "text": "Here's another document. 
It's a bit different from the first one.", + "doc_id": 2, + }, + ] + + info = pipeline.run(documents_source(initial_docs)) + assert_load_info(info) + + updated_docs = [ + { + "text": "This is the first document, but it has been updated with new content.", + "doc_id": 1, + }, + { + "text": "This is a completely new document that wasn't in the initial set.", + "doc_id": 3, + }, + ] + + info = pipeline.run(documents_source(updated_docs)) + assert_load_info(info) + + with pipeline.destination_client() as client: + embeddings_table_name = client.make_qualified_table_name("document") # type: ignore[attr-defined] + tbl: Table = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] + df = tbl.to_pandas() + + # Check (non-empty) embeddings as present, and that orphaned embeddings have been discarded. + assert len(df) == 21 + assert "vector" in df.columns + for _, vector in enumerate(df["vector"]): + assert isinstance(vector, np.ndarray) + assert vector.size > 0 + + +def test_lancedb_compound_merge_key_root_table() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_compound_merge_key", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_root_table_{uniq_id()}", + dev_mode=True, + ) + + @dlt.resource( + table_name="root", + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key=["doc_id", "chunk_hash"], + merge_key=["doc_id", "chunk_hash"], + ) + def identity_resource( + data: List[DictStrAny], + ) -> Generator[List[DictStrAny], None, None]: + yield data + + lancedb_adapter(identity_resource, no_remove_orphans=True) + + run_1 = [ + {"doc_id": 1, "chunk_hash": "a", "foo": "bar"}, + {"doc_id": 1, "chunk_hash": "b", "foo": "coo"}, + ] + info = pipeline.run(identity_resource(run_1)) + assert_load_info(info) + + run_2 = [ + {"doc_id": 1, "chunk_hash": "a", "foo": "aat"}, + {"doc_id": 1, "chunk_hash": "c", "foo": "loot"}, + ] + info = pipeline.run(identity_resource(run_2)) + assert_load_info(info) + + with pipeline.destination_client() as client: + expected_root_table_df = ( + pd.DataFrame( + data=[ + {"doc_id": 1, "chunk_hash": "a", "foo": "aat"}, + {"doc_id": 1, "chunk_hash": "b", "foo": "coo"}, + {"doc_id": 1, "chunk_hash": "c", "foo": "loot"}, + ] + ) + .sort_values(by=["doc_id", "chunk_hash", "foo"]) + .reset_index(drop=True) + ) + + root_table_name = client.make_qualified_table_name("root") # type: ignore[attr-defined] + tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] + + actual_root_df: DataFrame = ( + tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash", "foo"]).reset_index(drop=True) + )[["doc_id", "chunk_hash", "foo"]] + + assert_frame_equal(actual_root_df, expected_root_table_df) + + +def test_must_provide_at_least_primary_key_on_merge_disposition() -> None: + """We need upsert merge's deterministic _dlt_id to perform orphan removal. + Hence, we require at least the primary key required (raises exception if missing). 
+ Specify a merge key for custom orphan identification.""" + generator_instance1 = sequence_generator() + + @dlt.resource(write_disposition={"disposition": "merge", "strategy": "upsert"}) + def some_data() -> Generator[DictStrStr, Any, None]: + yield from next(generator_instance1) + + pipeline = dlt.pipeline( + pipeline_name="test_must_provide_both_primary_and_merge_key_on_merge_disposition", + destination="lancedb", + dataset_name=( + f"test_must_provide_both_primary_and_merge_key_on_merge_disposition{uniq_id()}" + ), + ) + with pytest.raises(Exception): + load_info = pipeline.run( + some_data(), + ) + assert_load_info(load_info) diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 6cd0abd587..345934fb29 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -1,25 +1,30 @@ import multiprocessing -from typing import Iterator, Generator, Any, List, Mapping +import os +from typing import Iterator, Generator, Any, List +from typing import Mapping +from typing import Union, Dict import pytest -import lancedb # type: ignore -from lancedb import DBConnection +from lancedb import DBConnection # type: ignore from lancedb.embeddings import EmbeddingFunctionRegistry # type: ignore +from lancedb.table import Table # type: ignore import dlt from dlt.common import json -from dlt.common.typing import DictStrStr, DictStrAny -from dlt.common.utils import uniq_id +from dlt.common.typing import DictStrAny +from dlt.common.typing import DictStrStr +from dlt.common.utils import uniq_id, digest128 from dlt.destinations.impl.lancedb.lancedb_adapter import ( lancedb_adapter, VECTORIZE_HINT, ) from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient -from tests.load.lancedb.utils import assert_table +from dlt.extract import DltResource +from tests.load.lancedb.utils import assert_table, chunk_document, mock_embed from tests.load.utils import sequence_generator, drop_active_pipeline_data from tests.pipeline.utils import assert_load_info -# Mark all tests as essential, do not remove. +# Mark all tests as essential, don't remove. 
pytestmark = pytest.mark.essential @@ -49,6 +54,22 @@ def some_data() -> Generator[DictStrStr, Any, None]: "x-lancedb-embed": True, } + lancedb_adapter( + some_data, + merge_key="content", + ) + + # via merge_key + assert some_data._hints["merge_key"] == "content" + + assert some_data.columns["content"] == { # type: ignore + "name": "content", + "data_type": "text", + "x-lancedb-embed": True, + } + + assert some_data.compute_table_schema()["columns"]["content"]["merge_key"] is True + def test_basic_state_and_schema() -> None: generator_instance1 = sequence_generator() @@ -118,14 +139,13 @@ def some_data() -> Generator[DictStrStr, Any, None]: def test_explicit_append() -> None: - """Append should work even when the primary key is specified.""" data = [ {"doc_id": 1, "content": "1"}, {"doc_id": 2, "content": "2"}, {"doc_id": 3, "content": "3"}, ] - @dlt.resource(primary_key="doc_id") + @dlt.resource() def some_data() -> Generator[List[DictStrAny], Any, None]: yield data @@ -142,6 +162,7 @@ def some_data() -> Generator[List[DictStrAny], Any, None]: info = pipeline.run( some_data(), ) + assert_load_info(info) assert_table(pipeline, "some_data", items=data) @@ -156,25 +177,22 @@ def some_data() -> Generator[List[DictStrAny], Any, None]: def test_pipeline_replace() -> None: - generator_instance1 = sequence_generator() - generator_instance2 = sequence_generator() + os.environ["DATA_WRITER__BUFFER_MAX_ITEMS"] = "2" + os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "2" + + generator_instance1, generator_instance2 = (sequence_generator(), sequence_generator()) @dlt.resource def some_data() -> Generator[DictStrStr, Any, None]: yield from next(generator_instance1) - lancedb_adapter( - some_data, - embed=["content"], - ) - uid = uniq_id() pipeline = dlt.pipeline( pipeline_name="test_pipeline_replace", destination="lancedb", dataset_name="test_pipeline_replace_dataset" - + uid, # lancedb doesn't mandate any name normalization + + uid, # Lancedb doesn't mandate any name normalization. ) info = pipeline.run( @@ -263,23 +281,11 @@ def test_pipeline_merge() -> None: }, ] - @dlt.resource(primary_key="doc_id") + @dlt.resource(primary_key=["doc_id"]) def movies_data() -> Any: yield data - @dlt.resource(primary_key="doc_id", merge_key=["merge_id", "title"]) - def movies_data_explicit_merge_keys() -> Any: - yield data - - lancedb_adapter( - movies_data, - embed=["description"], - ) - - lancedb_adapter( - movies_data_explicit_merge_keys, - embed=["description"], - ) + lancedb_adapter(movies_data, embed=["description"], no_remove_orphans=True) pipeline = dlt.pipeline( pipeline_name="movies", @@ -288,7 +294,7 @@ def movies_data_explicit_merge_keys() -> Any: ) info = pipeline.run( movies_data(), - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": "upsert"}, dataset_name=f"MoviesDataset{uniq_id()}", ) assert_load_info(info) @@ -299,26 +305,11 @@ def movies_data_explicit_merge_keys() -> Any: info = pipeline.run( movies_data(), - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": "upsert"}, ) assert_load_info(info) assert_table(pipeline, "movies_data", items=data) - info = pipeline.run( - movies_data(), - write_disposition="merge", - ) - assert_load_info(info) - assert_table(pipeline, "movies_data", items=data) - - # Test with explicit merge keys. 
- info = pipeline.run( - movies_data_explicit_merge_keys(), - write_disposition="merge", - ) - assert_load_info(info) - assert_table(pipeline, "movies_data_explicit_merge_keys", items=data) - def test_pipeline_with_schema_evolution() -> None: data = [ @@ -388,9 +379,9 @@ def test_merge_github_nested() -> None: data = json.load(f) info = pipe.run( - lancedb_adapter(data[:17], embed=["title", "body"]), + lancedb_adapter(data[:17], embed=["title", "body"], no_remove_orphans=True), table_name="issues", - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": "upsert"}, primary_key="id", ) assert_load_info(info) @@ -426,18 +417,116 @@ def test_merge_github_nested() -> None: def test_empty_dataset_allowed() -> None: # dataset_name is optional so dataset name won't be autogenerated when not explicitly passed. pipe = dlt.pipeline(destination="lancedb", dev_mode=True) - client: LanceDBClient = pipe.destination_client() # type: ignore[assignment] assert pipe.dataset_name is None info = pipe.run(lancedb_adapter(["context", "created", "not a stop word"], embed=["value"])) # Dataset in load info is empty. assert info.dataset_name is None - client = pipe.destination_client() # type: ignore[assignment] - assert client.dataset_name is None - assert client.sentinel_table == "dltSentinelTable" + client = pipe.destination_client() + assert client.dataset_name is None # type: ignore + assert client.sentinel_table == "dltSentinelTable" # type: ignore assert_table(pipe, "content", expected_items_count=3) +def test_lancedb_remove_nested_orphaned_records_with_chunks() -> None: + @dlt.resource( + write_disposition={"disposition": "merge", "strategy": "upsert"}, + table_name="document", + primary_key=["doc_id"], + merge_key=["doc_id"], + ) + def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: + for doc in docs: + doc_id = doc["doc_id"] + chunks = chunk_document(doc["text"]) + embeddings = [ + { + "chunk_hash": digest128(chunk), + "chunk_text": chunk, + "embedding": mock_embed(), + } + for chunk in chunks + ] + yield {"doc_id": doc_id, "doc_text": doc["text"], "embeddings": embeddings} + + @dlt.source(max_table_nesting=1) + def documents_source( + docs: List[DictStrAny], + ) -> Union[Generator[Dict[str, Any], None, None], DltResource]: + return documents(docs) + + pipeline = dlt.pipeline( + pipeline_name="chunked_docs", + destination="lancedb", + dataset_name="chunked_documents", + dev_mode=True, + ) + + initial_docs = [ + { + "text": ( + "This is the first document. It contains some text that will be chunked and" + " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)" + ), + "doc_id": 1, + }, + { + "text": "Here's another document. It's a bit different from the first one.", + "doc_id": 2, + }, + ] + + info = pipeline.run(documents_source(initial_docs)) + assert_load_info(info) + + updated_docs = [ + { + "text": "This is the first document, but it has been updated with new content.", + "doc_id": 1, + }, + { + "text": "This is a completely new document that wasn't in the initial set.", + "doc_id": 3, + }, + ] + + info = pipeline.run(documents_source(updated_docs)) + assert_load_info(info) + + with pipeline.destination_client() as client: + # Orphaned chunks/documents must have been discarded. + # Shouldn't contain any text from `initial_docs' where doc_id=1. + expected_text = { + "Here's ano", + "ther docum", + "ent. 
It's ", + "a bit diff", + "erent from", + " the first", + " one.", + "This is th", + "e first do", + "cument, bu", + "t it has b", + "een update", + "d with new", + " content.", + "This is a ", + "completely", + " new docum", + "ent that w", + "asn't in t", + "he initial", + " set.", + } + + embeddings_table_name = client.make_qualified_table_name("document__embeddings") # type: ignore[attr-defined] + + tbl: Table = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] + df = tbl.to_pandas() + assert set(df["chunk_text"]) == expected_text + + search_data = [ {"text": "Frodo was a happy puppy"}, {"text": "There are several kittens playing"}, diff --git a/tests/load/lancedb/test_utils.py b/tests/load/lancedb/test_utils.py new file mode 100644 index 0000000000..d7f9729f26 --- /dev/null +++ b/tests/load/lancedb/test_utils.py @@ -0,0 +1,46 @@ +import pyarrow as pa +import pytest + +from dlt.destinations.impl.lancedb.utils import ( + create_filter_condition, + fill_empty_source_column_values_with_placeholder, +) + + +# Mark all tests as essential, don't remove. +pytestmark = pytest.mark.essential + + +def test_fill_empty_source_column_values_with_placeholder() -> None: + data = [ + pa.array(["", "hello", ""]), + pa.array(["hello", None, ""]), + pa.array([1, 2, 3]), + pa.array(["world", "", "arrow"]), + ] + table = pa.Table.from_arrays(data, names=["A", "B", "C", "D"]) + + source_columns = ["A", "B"] + placeholder = "placeholder" + + new_table = fill_empty_source_column_values_with_placeholder(table, source_columns, placeholder) + + expected_data = [ + pa.array(["placeholder", "hello", "placeholder"]), + pa.array(["hello", "placeholder", "placeholder"]), + pa.array([1, 2, 3]), + pa.array(["world", "", "arrow"]), + ] + expected_table = pa.Table.from_arrays(expected_data, names=["A", "B", "C", "D"]) + assert new_table.equals(expected_table) + + +def test_create_filter_condition() -> None: + assert ( + create_filter_condition("_dlt_load_id", pa.array(["A", "B", "C'c\n"])) + == "_dlt_load_id IN ('A', 'B', 'C''c\\n')" + ) + assert ( + create_filter_condition("_dlt_load_id", pa.array([1.2, 3, 5 / 2])) + == "_dlt_load_id IN (1.2, 3.0, 2.5)" + ) diff --git a/tests/load/lancedb/utils.py b/tests/load/lancedb/utils.py index 7431e895b7..30430fe076 100644 --- a/tests/load/lancedb/utils.py +++ b/tests/load/lancedb/utils.py @@ -40,7 +40,7 @@ def assert_table( exists = client.table_exists(qualified_table_name) assert exists - records = client.db_client.open_table(qualified_table_name).search().limit(50).to_list() + records = client.db_client.open_table(qualified_table_name).search().limit(0).to_list() if expected_items_count is not None: assert expected_items_count == len(records) @@ -51,7 +51,6 @@ def assert_table( drop_keys = [ "_dlt_id", "_dlt_load_id", - dlt.config.get("destination.lancedb.credentials.id_field_name", str) or "id__", dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector", ] objects_without_dlt_or_special_keys = [ @@ -72,3 +71,13 @@ def generate_embeddings( def ndims(self) -> int: return 2 + + +def mock_embed( + dim: int = 10, +) -> str: + return str(np.random.random_sample(dim)) + + +def chunk_document(doc: str, chunk_size: int = 10) -> List[str]: + return [doc[i : i + chunk_size] for i in range(0, len(doc), chunk_size)] diff --git a/tests/load/utils.py b/tests/load/utils.py index 3edf111a36..e7d6476a3a 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -1004,7 +1004,7 @@ def prepare_load_package( def sequence_generator() -> 
Generator[List[Dict[str, str]], None, None]:
     count = 1
     while True:
-        yield [{"content": str(count + i)} for i in range(3)]
+        yield [{"content": str(count + i)} for i in range(2000)]
         count += 3
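
Note: the upsert merge flow exercised by the tests above reduces to a small amount of user code. The sketch below is illustrative only and is not part of the patch; the pipeline, dataset, table, and column names are placeholders, and it assumes a LanceDB destination and an embedding provider are already configured.

import dlt
from dlt.destinations.impl.lancedb.lancedb_adapter import lancedb_adapter


@dlt.resource(
    table_name="document",
    write_disposition={"disposition": "merge", "strategy": "upsert"},
    primary_key=["doc_id", "chunk_hash"],  # upsert needs a primary key to derive a deterministic _dlt_id
    merge_key="doc_id",  # re-sent doc_ids replace their previous chunks; orphaned chunks are removed
)
def documents(docs):
    yield from docs


# Request embeddings for the chunk text; pass no_remove_orphans=True to keep orphaned records instead.
lancedb_adapter(documents, embed=["chunk_text"])

pipeline = dlt.pipeline(
    pipeline_name="lancedb_upsert_demo",  # placeholder
    destination="lancedb",
    dataset_name="lancedb_upsert_demo_data",  # placeholder
)

# The first run inserts; later runs upsert on (doc_id, chunk_hash) and drop stale chunks per doc_id.
info = pipeline.run(documents([{"doc_id": 1, "chunk_hash": "a", "chunk_text": "hello"}]))
print(info)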