From 68e26a0587e79e9e2cb61d3649086e3108911fff Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 16 Jul 2024 20:56:18 +0200 Subject: [PATCH 001/113] Add tests for LanceDB chunking and merging functionality Signed-off-by: Marcel Coetzee --- tests/load/lancedb/test_pipeline.py | 89 ++++++++++++++++++++++++++++- 1 file changed, 87 insertions(+), 2 deletions(-) diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index e817a2f6c8..3f2f1298ba 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -16,7 +16,7 @@ from tests.pipeline.utils import assert_load_info -# Mark all tests as essential, do not remove. +# Mark all tests as essential, don't remove. pytestmark = pytest.mark.essential @@ -426,10 +426,95 @@ def test_empty_dataset_allowed() -> None: client: LanceDBClient = pipe.destination_client() # type: ignore[assignment] assert pipe.dataset_name is None - info = pipe.run(lancedb_adapter(["context", "created", "not a stop word"], embed=["value"])) + info = pipe.run( + lancedb_adapter(["context", "created", "not a stop word"], embed=["value"]) + ) # Dataset in load info is empty. assert info.dataset_name is None client = pipe.destination_client() # type: ignore[assignment] assert client.dataset_name is None assert client.sentinel_table == "dltSentinelTable" assert_table(pipe, "content", expected_items_count=3) + + +docs = [ + [ + { + "text": "This is the first document. It contains some text that will be chunked and embedded. (I don't want " + "to be seen in updated run's embedding chunk texts btw)", + "id": 1, + }, + { + "text": "Here's another document. It's a bit different from the first one.", + "id": 2, + }, + ], + [ + { + "text": "This is the first document, but it has been updated with new content.", + "id": 1, + }, + { + "text": "This is a completely new document that wasn't in the initial set.", + "id": 3, + }, + ], +] + + +def splitter(text: str, chunk_size: int = 10) -> List[str]: + return [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)] + + +def test_chunking_no_splitter() -> None: + pipe = dlt.pipeline(destination="lancedb", dataset_name="docs", dev_mode=True) + info = pipe.run( + docs[0], + table_name="documents", + ) + assert_load_info(info) + + # TODO: Check and compare output + + +def test_chunking_with_splitter() -> None: + pipe = dlt.pipeline(destination="lancedb", dataset_name="docs", dev_mode=True) + + info = pipe.run( + lancedb_adapter(docs[0], embed="text", splitter=splitter), + table_name="documents", + ) + assert_load_info(info) + + # TODO: Check and compare output + + +def test_chunk_merge() -> None: + """Test chunking is applied without orphaned chunks when new documents arrive.""" + + pipe = dlt.pipeline(destination="lancedb", dataset_name="docs", dev_mode=True) + + + info = pipe.run( + lancedb_adapter(docs[0], embed="text", splitter=splitter), + table_name="documents", + write_disposition="merge", + primary_key="id", + ) + pipe.run(info) + + # Orphaned chunks must be discarded. 
+ info = pipe.run( + lancedb_adapter(docs[1], embed="text", splitter=splitter), + table_name="documents", + write_disposition="merge", + primary_key="id", + ) + assert_load_info(info) + + # TODO: Check and compare output + + +def test_embedding_provider_only_called_once_per_chunk_hash() -> None: + """Verify that the embedding provider is called only once for each unique chunk hash to optimize API usage and reduce costs.""" + raise NotImplementedError From 900c4faabf346587ad558439600f464cabe33486 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 18 Jul 2024 18:32:58 +0200 Subject: [PATCH 002/113] Add TSplitter type alias for LanceDB document splitting function Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/typing.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 dlt/destinations/impl/lancedb/typing.py diff --git a/dlt/destinations/impl/lancedb/typing.py b/dlt/destinations/impl/lancedb/typing.py new file mode 100644 index 0000000000..c97b23fa7c --- /dev/null +++ b/dlt/destinations/impl/lancedb/typing.py @@ -0,0 +1,4 @@ +from collections.abc import Callable +from typing import List, Any + +TSplitter = Callable[[str, Any], List[str | Any]] From 16230a70506edac7c676cb25cbda2a3078be6e11 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 18 Jul 2024 18:57:11 +0200 Subject: [PATCH 003/113] Refine typing for chunks Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/typing.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/impl/lancedb/typing.py b/dlt/destinations/impl/lancedb/typing.py index c97b23fa7c..bf56955b79 100644 --- a/dlt/destinations/impl/lancedb/typing.py +++ b/dlt/destinations/impl/lancedb/typing.py @@ -1,4 +1,6 @@ -from collections.abc import Callable -from typing import List, Any +from typing import Callable, Union, List, Dict, Any -TSplitter = Callable[[str, Any], List[str | Any]] +ChunkInputT = Union[str, Dict[str, Any], Any] +ChunkOutputT = List[Any] + +TSplitter = Callable[[ChunkInputT, Any], ChunkOutputT] From 3f7a82f972561266d14e3e6c4e6880f1757ac5ff Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Fri, 19 Jul 2024 12:18:14 +0200 Subject: [PATCH 004/113] Add type definitions for chunk splitter function and related types Signed-off-by: Marcel Coetzee --- dlt/common/schema/typing.py | 5 +++++ dlt/destinations/impl/lancedb/typing.py | 6 ------ 2 files changed, 5 insertions(+), 6 deletions(-) delete mode 100644 dlt/destinations/impl/lancedb/typing.py diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index 9a4dd51d4b..d0ada7a778 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -75,6 +75,11 @@ TColumnNames = Union[str, Sequence[str]] """A string representing a column name or a list of""" +ChunkInputT = Union[str, Dict[str, Any], Any] +ChunkOutputT = List[Any] +TSplitter = Callable[[ChunkInputT, Any], ChunkOutputT] +"""Splitter function that takes a ChunkInputT and any additional arguments, returning a ChunkOutputT.""" + # COLUMN_PROPS: Set[TColumnProp] = set(get_args(TColumnProp)) COLUMN_HINTS: Set[TColumnHint] = set( [ diff --git a/dlt/destinations/impl/lancedb/typing.py b/dlt/destinations/impl/lancedb/typing.py deleted file mode 100644 index bf56955b79..0000000000 --- a/dlt/destinations/impl/lancedb/typing.py +++ /dev/null @@ -1,6 +0,0 @@ -from typing import Callable, Union, List, Dict, Any - -ChunkInputT = Union[str, Dict[str, Any], Any] -ChunkOutputT = List[Any] - -TSplitter = Callable[[ChunkInputT, Any], ChunkOutputT] From 
1dda1d517492ff5a25b220f69b960b3e661ef210 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Fri, 19 Jul 2024 14:03:20 +0200 Subject: [PATCH 005/113] Remove unused ChunkInputT, ChunkOutputT, and TSplitter type definitions Signed-off-by: Marcel Coetzee --- dlt/common/schema/typing.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index d0ada7a778..9a4dd51d4b 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -75,11 +75,6 @@ TColumnNames = Union[str, Sequence[str]] """A string representing a column name or a list of""" -ChunkInputT = Union[str, Dict[str, Any], Any] -ChunkOutputT = List[Any] -TSplitter = Callable[[ChunkInputT, Any], ChunkOutputT] -"""Splitter function that takes a ChunkInputT and any additional arguments, returning a ChunkOutputT.""" - # COLUMN_PROPS: Set[TColumnProp] = set(get_args(TColumnProp)) COLUMN_HINTS: Set[TColumnHint] = set( [ From 48e14ab64f2c7fb57c0ec748d2f3647811dff53e Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 21 Jul 2024 20:43:35 +0200 Subject: [PATCH 006/113] Implement efficient update strategy for chunked documents in LanceDB Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 41 +++++++++++++++---- tests/load/lancedb/test_pipeline.py | 31 +++----------- 2 files changed, 39 insertions(+), 33 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 8265e50fbf..e3025d19bf 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -1,7 +1,6 @@ import uuid from types import TracebackType from typing import ( - ClassVar, List, Any, cast, @@ -17,6 +16,7 @@ import lancedb # type: ignore import pyarrow as pa +import pyarrow.compute as pc from lancedb import DBConnection from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore @@ -77,7 +77,6 @@ else: NDArray = ndarray - TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} @@ -145,7 +144,7 @@ def from_db_type( if (precision, scale) == self.capabilities.wei_precision: return cast(TColumnType, dict(data_type="wei")) return dict(data_type="decimal", precision=precision, scale=scale) - return super().from_db_type(db_type, precision, scale) + return super().from_db_type(cast(str, db_type), precision, scale) def upload_batch( @@ -154,7 +153,8 @@ def upload_batch( *, db_client: DBConnection, table_name: str, - write_disposition: TWriteDisposition, + parent_table_name: Optional[str] = None, + write_disposition: Optional[TWriteDisposition] = "append", id_field_name: Optional[str] = None, ) -> None: """Inserts records into a LanceDB table with automatic embedding computation. @@ -163,6 +163,7 @@ def upload_batch( records: The data to be inserted as payload. db_client: The LanceDB client connection. table_name: The name of the table to insert into. + parent_table_name: The name of the parent table, if the target table has any. id_field_name: The name of the ID field for update/merge operations. write_disposition: The write disposition - one of 'skip', 'append', 'replace', 'merge'. 
@@ -190,6 +191,23 @@ def upload_batch( tbl.merge_insert( id_field_name ).when_matched_update_all().when_not_matched_insert_all().execute(records) + + # Remove orphaned parent IDs. + if parent_table_name: + try: + parent_tbl = db_client.open_table(parent_table_name) + parent_tbl.checkout_latest() + except FileNotFoundError as e: + raise DestinationTransientException( + "Couldn't open lancedb database. Batch WILL BE RETRIED" + ) from e + + parent_ids = set(pc.unique(parent_tbl.to_arrow()["_dlt_id"]).to_pylist()) + child_ids = set(pc.unique(tbl.to_arrow()["_dlt_parent_id"]).to_pylist()) + + if orphaned_ids := child_ids - parent_ids: + tbl.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") + else: raise DestinationTerminalException( f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" @@ -334,7 +352,7 @@ def drop_storage(self) -> None: Deletes all tables in the dataset and all data, as well as sentinel table associated with them. - If the dataset name was not provided, it deletes all the tables in the current schema. + If the dataset name wasn't provided, it deletes all the tables in the current schema. """ for table_name in self._get_table_names(): self.db_client.drop_table(table_name) @@ -459,9 +477,6 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: existing_columns, self.capabilities.generates_case_sensitive_identifiers(), ) - embedding_fields: List[str] = get_columns_names_with_prop( - self.schema.get_table(table_name), VECTORIZE_HINT - ) logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") if len(new_columns) > 0: if exists: @@ -524,10 +539,12 @@ def update_schema_in_storage(self) -> None: write_disposition = self.schema.get_table(self.schema.version_table_name).get( "write_disposition" ) + upload_batch( records, db_client=self.db_client, table_name=fq_version_table_name, + parent_table_name=None, write_disposition=write_disposition, ) @@ -690,6 +707,8 @@ def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + parent_table = table.get("parent") + return LoadLanceDBJob( self.schema, table, @@ -699,6 +718,9 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> client_config=self.config, model_func=self.model_func, fq_table_name=self.make_qualified_table_name(table["name"]), + fq_parent_table_name=( + self.make_qualified_table_name(parent_table) if parent_table else None + ), ) def table_exists(self, table_name: str) -> bool: @@ -718,6 +740,7 @@ def __init__( client_config: LanceDBClientConfiguration, model_func: TextEmbeddingFunction, fq_table_name: str, + fq_parent_table_name: Optional[str], ) -> None: file_name = FileStorage.get_file_name_from_file_path(local_path) super().__init__(file_name) @@ -727,6 +750,7 @@ def __init__( self.type_mapper: TypeMapper = type_mapper self.table_name: str = table_schema["name"] self.fq_table_name: str = fq_table_name + self.fq_parent_table_name: Optional[str] = fq_parent_table_name self.unique_identifiers: Sequence[str] = list_merge_identifiers(table_schema) self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) self.embedding_model_func: TextEmbeddingFunction = model_func @@ -759,6 +783,7 @@ def __init__( records, db_client=db_client, table_name=self.fq_table_name, + parent_table_name=self.fq_parent_table_name, write_disposition=self.write_disposition, 
id_field_name=self.id_field_name, ) diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 3f2f1298ba..bc4af48c50 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -426,9 +426,7 @@ def test_empty_dataset_allowed() -> None: client: LanceDBClient = pipe.destination_client() # type: ignore[assignment] assert pipe.dataset_name is None - info = pipe.run( - lancedb_adapter(["context", "created", "not a stop word"], embed=["value"]) - ) + info = pipe.run(lancedb_adapter(["context", "created", "not a stop word"], embed=["value"])) # Dataset in load info is empty. assert info.dataset_name is None client = pipe.destination_client() # type: ignore[assignment] @@ -440,8 +438,10 @@ def test_empty_dataset_allowed() -> None: docs = [ [ { - "text": "This is the first document. It contains some text that will be chunked and embedded. (I don't want " - "to be seen in updated run's embedding chunk texts btw)", + "text": ( + "This is the first document. It contains some text that will be chunked and" + " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)" + ), "id": 1, }, { @@ -462,10 +462,6 @@ def test_empty_dataset_allowed() -> None: ] -def splitter(text: str, chunk_size: int = 10) -> List[str]: - return [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)] - - def test_chunking_no_splitter() -> None: pipe = dlt.pipeline(destination="lancedb", dataset_name="docs", dev_mode=True) info = pipe.run( @@ -477,24 +473,9 @@ def test_chunking_no_splitter() -> None: # TODO: Check and compare output -def test_chunking_with_splitter() -> None: - pipe = dlt.pipeline(destination="lancedb", dataset_name="docs", dev_mode=True) - - info = pipe.run( - lancedb_adapter(docs[0], embed="text", splitter=splitter), - table_name="documents", - ) - assert_load_info(info) - - # TODO: Check and compare output - - -def test_chunk_merge() -> None: - """Test chunking is applied without orphaned chunks when new documents arrive.""" - +def test_chunk_merge_no_splitter() -> None: pipe = dlt.pipeline(destination="lancedb", dataset_name="docs", dev_mode=True) - info = pipe.run( lancedb_adapter(docs[0], embed="text", splitter=splitter), table_name="documents", From 32fe174590f6e63edcfae4fde144855692ce2861 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 21 Jul 2024 22:42:58 +0200 Subject: [PATCH 007/113] Implement efficient update strategy for chunked documents in LanceDB Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 17 ++- tests/load/lancedb/test_pipeline.py | 144 +++++++++++------- tests/load/lancedb/utils.py | 25 ++- 3 files changed, 121 insertions(+), 65 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index e3025d19bf..05e327a763 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -141,7 +141,7 @@ def from_db_type( ) if isinstance(db_type, pa.Decimal128Type): precision, scale = db_type.precision, db_type.scale - if (precision, scale) == self.capabilities.wei_precision: + if (precision, scale)==self.capabilities.wei_precision: return cast(TColumnType, dict(data_type="wei")) return dict(data_type="decimal", precision=precision, scale=scale) return super().from_db_type(cast(str, db_type), precision, scale) @@ -183,9 +183,9 @@ def upload_batch( try: if write_disposition in ("append", "skip"): tbl.add(records) - elif write_disposition == "replace": + 
elif write_disposition=="replace": tbl.add(records, mode="overwrite") - elif write_disposition == "merge": + elif write_disposition=="merge": if not id_field_name: raise ValueError("To perform a merge update, 'id_field_name' must be specified.") tbl.merge_insert( @@ -206,7 +206,10 @@ def upload_batch( child_ids = set(pc.unique(tbl.to_arrow()["_dlt_parent_id"]).to_pylist()) if orphaned_ids := child_ids - parent_ids: - tbl.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") + if len(orphaned_ids) > 1: + tbl.delete(f"_dlt_parent_id IN {tuple(orphaned_ids) if len(orphaned_ids) > 1 else orphaned_ids.pop()}") + elif len(orphaned_ids) == 1: + tbl.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") else: raise DestinationTerminalException( @@ -251,7 +254,7 @@ def __init__( self.config.credentials.embedding_model_provider_api_key, ) # Use the monkey-patched implementation if openai was chosen. - if embedding_model_provider == "openai": + if embedding_model_provider=="openai": from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings self.model_func = PatchedOpenAIEmbeddings( @@ -344,7 +347,7 @@ def _get_table_names(self) -> List[str]: else: table_names = self.db_client.table_names() - return [table_name for table_name in table_names if table_name != self.sentinel_table] + return [table_name for table_name in table_names if table_name!=self.sentinel_table] @lancedb_error def drop_storage(self) -> None: @@ -585,7 +588,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: loads_table, keys=p_dlt_load_id, right_keys=p_load_id, join_type="inner" ).sort_by([(p_dlt_load_id, "descending")]) - if joined_table.num_rows == 0: + if joined_table.num_rows==0: return None state = joined_table.take([0]).to_pylist()[0] diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index bc4af48c50..a0776559f8 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -1,17 +1,21 @@ from typing import Iterator, Generator, Any, List +from typing import Union, Dict import pytest +from lancedb.table import Table import dlt from dlt.common import json -from dlt.common.typing import DictStrStr, DictStrAny -from dlt.common.utils import uniq_id +from dlt.common.typing import DictStrAny +from dlt.common.typing import DictStrStr +from dlt.common.utils import uniq_id, digest128 from dlt.destinations.impl.lancedb.lancedb_adapter import ( lancedb_adapter, VECTORIZE_HINT, ) from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient -from tests.load.lancedb.utils import assert_table +from dlt.extract import DltResource +from tests.load.lancedb.utils import assert_table, chunk_document, mock_embed from tests.load.utils import sequence_generator, drop_active_pipeline_data from tests.pipeline.utils import assert_load_info @@ -426,7 +430,9 @@ def test_empty_dataset_allowed() -> None: client: LanceDBClient = pipe.destination_client() # type: ignore[assignment] assert pipe.dataset_name is None - info = pipe.run(lancedb_adapter(["context", "created", "not a stop word"], embed=["value"])) + info = pipe.run( + lancedb_adapter(["context", "created", "not a stop word"], embed=["value"]) + ) # Dataset in load info is empty. 
assert info.dataset_name is None client = pipe.destination_client() # type: ignore[assignment] @@ -435,67 +441,97 @@ def test_empty_dataset_allowed() -> None: assert_table(pipe, "content", expected_items_count=3) -docs = [ - [ +def test_merge_no_orphans() -> None: + @dlt.resource( + write_disposition="merge", + merge_key=["doc_id", "chunk_hash"], + table_name="document", + ) + def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: + for doc in docs: + doc_id = doc["doc_id"] + chunks = chunk_document(doc["text"]) + embeddings = [ + { + "chunk_hash": digest128(chunk), + "chunk_text": chunk, + "embedding": mock_embed(), + } + for chunk in chunks + ] + yield {"doc_id": doc_id, "doc_text": doc["text"], "embeddings": embeddings} + + @dlt.source(max_table_nesting=1) + def documents_source( + docs: List[DictStrAny], + ) -> Union[Generator[Dict[str, Any], None, None], DltResource]: + return documents(docs) + + pipeline = dlt.pipeline( + pipeline_name="chunked_docs", + destination="lancedb", + dataset_name="chunked_documents", + dev_mode=True, + ) + + initial_docs = [ { - "text": ( - "This is the first document. It contains some text that will be chunked and" - " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)" - ), - "id": 1, + "text": "This is the first document. It contains some text that will be chunked and embedded. (I don't want " + "to be seen in updated run's embedding chunk texts btw)", + "doc_id": 1, }, { "text": "Here's another document. It's a bit different from the first one.", - "id": 2, + "doc_id": 2, }, - ], - [ + ] + + info = pipeline.run(documents_source(initial_docs)) + assert_load_info(info) + + updated_docs = [ { "text": "This is the first document, but it has been updated with new content.", - "id": 1, + "doc_id": 1, }, { "text": "This is a completely new document that wasn't in the initial set.", - "id": 3, + "doc_id": 3, }, - ], -] - - -def test_chunking_no_splitter() -> None: - pipe = dlt.pipeline(destination="lancedb", dataset_name="docs", dev_mode=True) - info = pipe.run( - docs[0], - table_name="documents", - ) - assert_load_info(info) - - # TODO: Check and compare output - - -def test_chunk_merge_no_splitter() -> None: - pipe = dlt.pipeline(destination="lancedb", dataset_name="docs", dev_mode=True) - - info = pipe.run( - lancedb_adapter(docs[0], embed="text", splitter=splitter), - table_name="documents", - write_disposition="merge", - primary_key="id", - ) - pipe.run(info) + ] - # Orphaned chunks must be discarded. - info = pipe.run( - lancedb_adapter(docs[1], embed="text", splitter=splitter), - table_name="documents", - write_disposition="merge", - primary_key="id", - ) + info = pipeline.run(documents_source(updated_docs)) assert_load_info(info) - # TODO: Check and compare output - - -def test_embedding_provider_only_called_once_per_chunk_hash() -> None: - """Verify that the embedding provider is called only once for each unique chunk hash to optimize API usage and reduce costs.""" - raise NotImplementedError + with pipeline.destination_client() as client: + # Orphaned chunks/documents must have been discarded. + # Shouldn't contain any text from `initial_docs' where doc_id=1. + expected_text = { + "Here's ano", + "ther docum", + "ent. 
It's ", + "a bit diff", + "erent from", + " the first", + " one.", + "This is th", + "e first do", + "cument, bu", + "t it has b", + "een update", + "d with new", + " content.", + "This is a ", + "completely", + " new docum", + "ent that w", + "asn't in t", + "he initial", + " set.", + } + + embeddings_table_name = client.make_qualified_table_name("document__embeddings") # type: ignore[attr-defined] + + tbl: Table = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] + df = tbl.to_pandas() + assert set(df["chunk_text"]) == expected_text diff --git a/tests/load/lancedb/utils.py b/tests/load/lancedb/utils.py index dc3ea5304b..2fdc8e5b40 100644 --- a/tests/load/lancedb/utils.py +++ b/tests/load/lancedb/utils.py @@ -22,8 +22,12 @@ def assert_unordered_dicts_equal( """ assert len(dict_list1) == len(dict_list2), "Lists have different length" - dict_set1 = {tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list1} - dict_set2 = {tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list2} + dict_set1 = { + tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list1 + } + dict_set2 = { + tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list2 + } assert dict_set1 == dict_set2, "Lists contain different dictionaries" @@ -40,7 +44,9 @@ def assert_table( exists = client.table_exists(qualified_table_name) assert exists - records = client.db_client.open_table(qualified_table_name).search().limit(50).to_list() + records = ( + client.db_client.open_table(qualified_table_name).search().limit(50).to_list() + ) if expected_items_count is not None: assert expected_items_count == len(records) @@ -52,7 +58,8 @@ def assert_table( "_dlt_id", "_dlt_load_id", dlt.config.get("destination.lancedb.credentials.id_field_name", str) or "id__", - dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector__", + dlt.config.get("destination.lancedb.credentials.vector_field_name", str) + or "vector__", ] objects_without_dlt_or_special_keys = [ {k: v for k, v in record.items() if k not in drop_keys} for record in records @@ -72,3 +79,13 @@ def generate_embeddings( def ndims(self) -> int: return 2 + + +def mock_embed( + dim: int = 10, +) -> str: + return str(np.random.random_sample(dim)) + + +def chunk_document(doc: str, chunk_size: int = 10) -> List[str]: + return [doc[i : i + chunk_size] for i in range(0, len(doc), chunk_size)] From d97496289d55bfdc3146e4dcef80ad2cadab7f49 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 21 Jul 2024 22:58:18 +0200 Subject: [PATCH 008/113] Refactor LanceDB client and tests for improved readability and type safety Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 17 ++++++++++------- tests/load/lancedb/test_pipeline.py | 12 ++++++------ tests/load/lancedb/utils.py | 15 ++++----------- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 05e327a763..981b062246 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -141,7 +141,7 @@ def from_db_type( ) if isinstance(db_type, pa.Decimal128Type): precision, scale = db_type.precision, db_type.scale - if (precision, scale)==self.capabilities.wei_precision: + if (precision, scale) == self.capabilities.wei_precision: return cast(TColumnType, dict(data_type="wei")) return dict(data_type="decimal", 
precision=precision, scale=scale) return super().from_db_type(cast(str, db_type), precision, scale) @@ -183,9 +183,9 @@ def upload_batch( try: if write_disposition in ("append", "skip"): tbl.add(records) - elif write_disposition=="replace": + elif write_disposition == "replace": tbl.add(records, mode="overwrite") - elif write_disposition=="merge": + elif write_disposition == "merge": if not id_field_name: raise ValueError("To perform a merge update, 'id_field_name' must be specified.") tbl.merge_insert( @@ -207,7 +207,10 @@ def upload_batch( if orphaned_ids := child_ids - parent_ids: if len(orphaned_ids) > 1: - tbl.delete(f"_dlt_parent_id IN {tuple(orphaned_ids) if len(orphaned_ids) > 1 else orphaned_ids.pop()}") + tbl.delete( + "_dlt_parent_id IN" + f" {tuple(orphaned_ids) if len(orphaned_ids) > 1 else orphaned_ids.pop()}" + ) elif len(orphaned_ids) == 1: tbl.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") @@ -254,7 +257,7 @@ def __init__( self.config.credentials.embedding_model_provider_api_key, ) # Use the monkey-patched implementation if openai was chosen. - if embedding_model_provider=="openai": + if embedding_model_provider == "openai": from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings self.model_func = PatchedOpenAIEmbeddings( @@ -347,7 +350,7 @@ def _get_table_names(self) -> List[str]: else: table_names = self.db_client.table_names() - return [table_name for table_name in table_names if table_name!=self.sentinel_table] + return [table_name for table_name in table_names if table_name != self.sentinel_table] @lancedb_error def drop_storage(self) -> None: @@ -588,7 +591,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: loads_table, keys=p_dlt_load_id, right_keys=p_load_id, join_type="inner" ).sort_by([(p_dlt_load_id, "descending")]) - if joined_table.num_rows==0: + if joined_table.num_rows == 0: return None state = joined_table.take([0]).to_pylist()[0] diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index a0776559f8..74d97cb88f 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -2,7 +2,7 @@ from typing import Union, Dict import pytest -from lancedb.table import Table +from lancedb.table import Table # type: ignore[import-untyped] import dlt from dlt.common import json @@ -430,9 +430,7 @@ def test_empty_dataset_allowed() -> None: client: LanceDBClient = pipe.destination_client() # type: ignore[assignment] assert pipe.dataset_name is None - info = pipe.run( - lancedb_adapter(["context", "created", "not a stop word"], embed=["value"]) - ) + info = pipe.run(lancedb_adapter(["context", "created", "not a stop word"], embed=["value"])) # Dataset in load info is empty. assert info.dataset_name is None client = pipe.destination_client() # type: ignore[assignment] @@ -476,8 +474,10 @@ def documents_source( initial_docs = [ { - "text": "This is the first document. It contains some text that will be chunked and embedded. (I don't want " - "to be seen in updated run's embedding chunk texts btw)", + "text": ( + "This is the first document. It contains some text that will be chunked and" + " embedded. 
(I don't want to be seen in updated run's embedding chunk texts btw)" + ), "doc_id": 1, }, { diff --git a/tests/load/lancedb/utils.py b/tests/load/lancedb/utils.py index 2fdc8e5b40..8dd56d22aa 100644 --- a/tests/load/lancedb/utils.py +++ b/tests/load/lancedb/utils.py @@ -22,12 +22,8 @@ def assert_unordered_dicts_equal( """ assert len(dict_list1) == len(dict_list2), "Lists have different length" - dict_set1 = { - tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list1 - } - dict_set2 = { - tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list2 - } + dict_set1 = {tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list1} + dict_set2 = {tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list2} assert dict_set1 == dict_set2, "Lists contain different dictionaries" @@ -44,9 +40,7 @@ def assert_table( exists = client.table_exists(qualified_table_name) assert exists - records = ( - client.db_client.open_table(qualified_table_name).search().limit(50).to_list() - ) + records = client.db_client.open_table(qualified_table_name).search().limit(50).to_list() if expected_items_count is not None: assert expected_items_count == len(records) @@ -58,8 +52,7 @@ def assert_table( "_dlt_id", "_dlt_load_id", dlt.config.get("destination.lancedb.credentials.id_field_name", str) or "id__", - dlt.config.get("destination.lancedb.credentials.vector_field_name", str) - or "vector__", + dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector__", ] objects_without_dlt_or_special_keys = [ {k: v for k, v in record.items() if k not in drop_keys} for record in records From e6cdf5d181343c78090ef86c8cbceb6347ab7c69 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 21 Jul 2024 23:20:06 +0200 Subject: [PATCH 009/113] Linting Signed-off-by: Marcel Coetzee --- tests/load/lancedb/test_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 74d97cb88f..e312a1e591 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -2,7 +2,7 @@ from typing import Union, Dict import pytest -from lancedb.table import Table # type: ignore[import-untyped] +from lancedb.table import Table # type: ignore import dlt from dlt.common import json From a60737aeaa1dc84f6735c7e62afe4ca6d6ac5168 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sat, 27 Jul 2024 17:13:27 +0200 Subject: [PATCH 010/113] Add document_id parameter to lancedb_adapter and update merge logic Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_adapter.py | 20 +++- dlt/destinations/impl/lancedb/utils.py | 4 +- tests/load/lancedb/test_pipeline.py | 113 +++++++++++++++++- 3 files changed, 134 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_adapter.py b/dlt/destinations/impl/lancedb/lancedb_adapter.py index bb33632b48..faa2ab3399 100644 --- a/dlt/destinations/impl/lancedb/lancedb_adapter.py +++ b/dlt/destinations/impl/lancedb/lancedb_adapter.py @@ -6,11 +6,13 @@ VECTORIZE_HINT = "x-lancedb-embed" +DOCUMENT_ID_HINT = "x-lancedb-doc-id" def lancedb_adapter( data: Any, embed: TColumnNames = None, + document_id: TColumnNames = None, ) -> DltResource: """Prepares data for the LanceDB destination by specifying which columns should be embedded. @@ -20,6 +22,8 @@ def lancedb_adapter( object. embed (TColumnNames, optional): Specify columns to generate embeddings for. 
It can be a single column name as a string, or a list of column names. + document_id (TColumnNames, optional): Specify columns which represenet the document + and which will be appended to primary/merge keys. Returns: DltResource: A resource with applied LanceDB-specific hints. @@ -50,8 +54,22 @@ def lancedb_adapter( VECTORIZE_HINT: True, # type: ignore[misc] } + if document_id: + if isinstance(document_id, str): + embed = [document_id] + if not isinstance(document_id, list): + raise ValueError( + "'document_id' must be a list of column names or a single column name as a string." + ) + + for column_name in document_id: + column_hints[column_name] = { + "name": column_name, + DOCUMENT_ID_HINT: True, # type: ignore[misc] + } + if not column_hints: - raise ValueError("A value for 'embed' must be specified.") + raise ValueError("At least one of 'embed' or 'document_id' must be specified.") else: resource.apply_hints(columns=column_hints) diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index aeacd4d34b..874852512a 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -6,6 +6,7 @@ from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider +from dlt.destinations.impl.lancedb.lancedb_adapter import DOCUMENT_ID_HINT PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = { @@ -41,9 +42,10 @@ def list_merge_identifiers(table_schema: TTableSchema) -> Sequence[str]: Sequence[str]: A list of unique column identifiers. """ if table_schema.get("write_disposition") == "merge": + document_id = get_columns_names_with_prop(table_schema, DOCUMENT_ID_HINT) primary_keys = get_columns_names_with_prop(table_schema, "primary_key") merge_keys = get_columns_names_with_prop(table_schema, "merge_key") - if join_keys := list(set(primary_keys + merge_keys)): + if join_keys := list(set(primary_keys + merge_keys + document_id)): return join_keys return get_columns_names_with_prop(table_schema, "unique") diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index e312a1e591..ea52e2f472 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -12,6 +12,7 @@ from dlt.destinations.impl.lancedb.lancedb_adapter import ( lancedb_adapter, VECTORIZE_HINT, + DOCUMENT_ID_HINT, ) from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient from dlt.extract import DltResource @@ -50,6 +51,18 @@ def some_data() -> Generator[DictStrStr, Any, None]: "x-lancedb-embed": True, } + lancedb_adapter( + some_data, + document_id=["content"], + ) + + assert some_data.columns["content"] == { # type: ignore + "name": "content", + "data_type": "text", + "x-lancedb-embed": True, + "x-lancedb-doc-id": True, + } + def test_basic_state_and_schema() -> None: generator_instance1 = sequence_generator() @@ -442,8 +455,106 @@ def test_empty_dataset_allowed() -> None: def test_merge_no_orphans() -> None: @dlt.resource( write_disposition="merge", - merge_key=["doc_id", "chunk_hash"], + primary_key=["doc_id"], + table_name="document", + ) + def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: + for doc in docs: + doc_id = doc["doc_id"] + chunks = chunk_document(doc["text"]) + embeddings = [ + { + "chunk_hash": digest128(chunk), + "chunk_text": chunk, + "embedding": mock_embed(), + } + for chunk in chunks + ] + yield {"doc_id": doc_id, 
"doc_text": doc["text"], "embeddings": embeddings} + + @dlt.source(max_table_nesting=1) + def documents_source( + docs: List[DictStrAny], + ) -> Union[Generator[Dict[str, Any], None, None], DltResource]: + return documents(docs) + + pipeline = dlt.pipeline( + pipeline_name="chunked_docs", + destination="lancedb", + dataset_name="chunked_documents", + dev_mode=True, + ) + + initial_docs = [ + { + "text": ( + "This is the first document. It contains some text that will be chunked and" + " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)" + ), + "doc_id": 1, + }, + { + "text": "Here's another document. It's a bit different from the first one.", + "doc_id": 2, + }, + ] + + info = pipeline.run(documents_source(initial_docs)) + assert_load_info(info) + + updated_docs = [ + { + "text": "This is the first document, but it has been updated with new content.", + "doc_id": 1, + }, + { + "text": "This is a completely new document that wasn't in the initial set.", + "doc_id": 3, + }, + ] + + info = pipeline.run(documents_source(updated_docs)) + assert_load_info(info) + + with pipeline.destination_client() as client: + # Orphaned chunks/documents must have been discarded. + # Shouldn't contain any text from `initial_docs' where doc_id=1. + expected_text = { + "Here's ano", + "ther docum", + "ent. It's ", + "a bit diff", + "erent from", + " the first", + " one.", + "This is th", + "e first do", + "cument, bu", + "t it has b", + "een update", + "d with new", + " content.", + "This is a ", + "completely", + " new docum", + "ent that w", + "asn't in t", + "he initial", + " set.", + } + + embeddings_table_name = client.make_qualified_table_name("document__embeddings") # type: ignore[attr-defined] + + tbl: Table = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] + df = tbl.to_pandas() + assert set(df["chunk_text"]) == expected_text + + +def test_merge_no_orphans_with_doc_id() -> None: + @dlt.resource( # type: ignore + write_disposition="merge", table_name="document", + columns={"doc_id": {DOCUMENT_ID_HINT: True}}, ) def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: for doc in docs: From 518a5078ad068a51627888d615e5090b2cf82501 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 29 Jul 2024 16:13:50 +0200 Subject: [PATCH 011/113] Remove resolved comments Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 122 +++++++++++++----- 1 file changed, 89 insertions(+), 33 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 981b062246..09c1cca372 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -78,7 +78,9 @@ NDArray = ndarray TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} -UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} +UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = { + v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items() +} class LanceDBTypeMapper(TypeMapper): @@ -187,7 +189,9 @@ def upload_batch( tbl.add(records, mode="overwrite") elif write_disposition == "merge": if not id_field_name: - raise ValueError("To perform a merge update, 'id_field_name' must be specified.") + raise ValueError( + "To perform a merge update, 'id_field_name' must be specified." 
+ ) tbl.merge_insert( id_field_name ).when_matched_update_all().when_not_matched_insert_all().execute(records) @@ -202,7 +206,9 @@ def upload_batch( "Couldn't open lancedb database. Batch WILL BE RETRIED" ) from e - parent_ids = set(pc.unique(parent_tbl.to_arrow()["_dlt_id"]).to_pylist()) + parent_ids = set( + pc.unique(parent_tbl.to_arrow()["_dlt_id"]).to_pylist() + ) child_ids = set(pc.unique(tbl.to_arrow()["_dlt_parent_id"]).to_pylist()) if orphaned_ids := child_ids - parent_ids: @@ -299,7 +305,9 @@ def get_table_schema(self, table_name: str) -> TArrowSchema: ) @lancedb_error - def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table: + def create_table( + self, table_name: str, schema: TArrowSchema, mode: str = "create" + ) -> Table: """Create a LanceDB Table from the provided LanceModel or PyArrow schema. Args: @@ -322,7 +330,9 @@ def delete_table(self, table_name: str) -> None: def query_table( self, table_name: str, - query: Union[List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None] = None, + query: Union[ + List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None + ] = None, ) -> LanceQueryBuilder: """Query a LanceDB table. @@ -350,7 +360,11 @@ def _get_table_names(self) -> List[str]: else: table_names = self.db_client.table_names() - return [table_name for table_name in table_names if table_name != self.sentinel_table] + return [ + table_name + for table_name in table_names + if table_name != self.sentinel_table + ] @lancedb_error def drop_storage(self) -> None: @@ -403,7 +417,9 @@ def update_stored_schema( applied_update: TSchemaTables = {} try: - schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash) + schema_info = self.get_stored_schema_by_hash( + self.schema.stored_version_hash + ) except DestinationUndefinedEntity: schema_info = None @@ -458,19 +474,25 @@ def add_table_fields( # Check if any of the new fields already exist in the table. existing_fields = set(arrow_table.schema.names) - new_fields = [field for field in field_schemas if field.name not in existing_fields] + new_fields = [ + field for field in field_schemas if field.name not in existing_fields + ] if not new_fields: # All fields already present, skip. return None - null_arrays = [pa.nulls(len(arrow_table), type=field.type) for field in new_fields] + null_arrays = [ + pa.nulls(len(arrow_table), type=field.type) for field in new_fields + ] for field, null_array in zip(new_fields, null_arrays): arrow_table = arrow_table.append_column(field, null_array) try: - return self.db_client.create_table(table_name, arrow_table, mode="overwrite") + return self.db_client.create_table( + table_name, arrow_table, mode="overwrite" + ) except OSError: # Error occurred while creating the table, skip. 
return None @@ -483,11 +505,15 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: existing_columns, self.capabilities.generates_case_sensitive_identifiers(), ) - logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") + logger.info( + f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}" + ) if len(new_columns) > 0: if exists: field_schemas: List[TArrowField] = [ - make_arrow_field_schema(column["name"], column, self.type_mapper) + make_arrow_field_schema( + column["name"], column, self.type_mapper + ) for column in new_columns ] fq_table_name = self.make_qualified_table_name(table_name) @@ -500,7 +526,9 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: vector_field_name = self.vector_field_name id_field_name = self.id_field_name embedding_model_func = self.model_func - embedding_model_dimensions = self.config.embedding_model_dimensions + embedding_model_dimensions = ( + self.config.embedding_model_dimensions + ) else: embedding_fields = None vector_field_name = None @@ -531,8 +559,12 @@ def update_schema_in_storage(self) -> None: self.schema.naming.normalize_identifier( "engine_version" ): self.schema.ENGINE_VERSION, - self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), - self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier("inserted_at"): str( + pendulum.now() + ), + self.schema.naming.normalize_identifier( + "schema_name" + ): self.schema.name, self.schema.naming.normalize_identifier( "version_hash" ): self.schema.stored_version_hash, @@ -541,7 +573,9 @@ def update_schema_in_storage(self) -> None: ), } ] - fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) + fq_version_table_name = self.make_qualified_table_name( + self.schema.version_table_name + ) write_disposition = self.schema.get_table(self.schema.version_table_name).get( "write_disposition" ) @@ -557,8 +591,12 @@ def update_schema_in_storage(self) -> None: @lancedb_error def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: """Retrieves the latest completed state for a pipeline.""" - fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name) - fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) + fq_state_table_name = self.make_qualified_table_name( + self.schema.state_table_name + ) + fq_loads_table_name = self.make_qualified_table_name( + self.schema.loads_table_name + ) state_table_: Table = self.db_client.open_table(fq_state_table_name) state_table_.checkout_latest() @@ -584,7 +622,9 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: .where(f"`{p_pipeline_name}` = '{pipeline_name}'", prefilter=True) .to_arrow() ) - loads_table = loads_table_.search().where(f"`{p_status}` = 0", prefilter=True).to_arrow() + loads_table = ( + loads_table_.search().where(f"`{p_status}` = 0", prefilter=True).to_arrow() + ) # Join arrow tables in-memory. 
joined_table: pa.Table = state_table.join( @@ -606,8 +646,12 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: ) @lancedb_error - def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: - fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) + def get_stored_schema_by_hash( + self, schema_hash: str + ) -> Optional[StorageSchemaInfo]: + fq_version_table_name = self.make_qualified_table_name( + self.schema.version_table_name + ) version_table: Table = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() @@ -625,9 +669,9 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI ) ).to_list() - # LanceDB's ORDER BY clause doesn't seem to work. - # See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341 - most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] + most_recent_schema = sorted( + schemas, key=lambda x: x[p_inserted_at], reverse=True + )[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], @@ -642,7 +686,9 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI @lancedb_error def get_stored_schema(self) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage.""" - fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) + fq_version_table_name = self.make_qualified_table_name( + self.schema.version_table_name + ) version_table: Table = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() @@ -660,9 +706,9 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: ) ).to_list() - # LanceDB's ORDER BY clause doesn't seem to work. - # See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341 - most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] + most_recent_schema = sorted( + schemas, key=lambda x: x[p_inserted_at], reverse=True + )[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], @@ -690,15 +736,21 @@ def complete_load(self, load_id: str) -> None: records = [ { self.schema.naming.normalize_identifier("load_id"): load_id, - self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier( + "schema_name" + ): self.schema.name, self.schema.naming.normalize_identifier("status"): 0, - self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), + self.schema.naming.normalize_identifier("inserted_at"): str( + pendulum.now() + ), self.schema.naming.normalize_identifier( "schema_version_hash" ): None, # Payload schema must match the target schema. 
} ] - fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) + fq_loads_table_name = self.make_qualified_table_name( + self.schema.loads_table_name + ) write_disposition = self.schema.get_table(self.schema.loads_table_name).get( "write_disposition" ) @@ -712,7 +764,9 @@ def complete_load(self, load_id: str) -> None: def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: parent_table = table.get("parent") return LoadLanceDBJob( @@ -758,7 +812,9 @@ def __init__( self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name self.unique_identifiers: Sequence[str] = list_merge_identifiers(table_schema) - self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) + self.embedding_fields: List[str] = get_columns_names_with_prop( + table_schema, VECTORIZE_HINT + ) self.embedding_model_func: TextEmbeddingFunction = model_func self.embedding_model_dimensions: int = client_config.embedding_model_dimensions self.id_field_name: str = client_config.id_field_name From c10bd734f8742b4b6a910c18d52c10f9279cedcc Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 29 Jul 2024 17:03:33 +0200 Subject: [PATCH 012/113] Implement efficient orphan removal for chunked documents in LanceDB Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 104 +++++++++++++----- 1 file changed, 74 insertions(+), 30 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 09c1cca372..4e3d5c7488 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -38,6 +38,7 @@ StorageSchemaInfo, StateInfo, TLoadJobState, + FollowupJob, ) from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables @@ -155,7 +156,6 @@ def upload_batch( *, db_client: DBConnection, table_name: str, - parent_table_name: Optional[str] = None, write_disposition: Optional[TWriteDisposition] = "append", id_field_name: Optional[str] = None, ) -> None: @@ -165,7 +165,6 @@ def upload_batch( records: The data to be inserted as payload. db_client: The LanceDB client connection. table_name: The name of the table to insert into. - parent_table_name: The name of the parent table, if the target table has any. id_field_name: The name of the ID field for update/merge operations. write_disposition: The write disposition - one of 'skip', 'append', 'replace', 'merge'. @@ -187,7 +186,7 @@ def upload_batch( tbl.add(records) elif write_disposition == "replace": tbl.add(records, mode="overwrite") - elif write_disposition == "merge": + elif write_disposition in ("merge" or "upsert"): if not id_field_name: raise ValueError( "To perform a merge update, 'id_field_name' must be specified." @@ -195,31 +194,6 @@ def upload_batch( tbl.merge_insert( id_field_name ).when_matched_update_all().when_not_matched_insert_all().execute(records) - - # Remove orphaned parent IDs. - if parent_table_name: - try: - parent_tbl = db_client.open_table(parent_table_name) - parent_tbl.checkout_latest() - except FileNotFoundError as e: - raise DestinationTransientException( - "Couldn't open lancedb database. 
Batch WILL BE RETRIED" - ) from e - - parent_ids = set( - pc.unique(parent_tbl.to_arrow()["_dlt_id"]).to_pylist() - ) - child_ids = set(pc.unique(tbl.to_arrow()["_dlt_parent_id"]).to_pylist()) - - if orphaned_ids := child_ids - parent_ids: - if len(orphaned_ids) > 1: - tbl.delete( - "_dlt_parent_id IN" - f" {tuple(orphaned_ids) if len(orphaned_ids) > 1 else orphaned_ids.pop()}" - ) - elif len(orphaned_ids) == 1: - tbl.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") - else: raise DestinationTerminalException( f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" @@ -584,7 +558,6 @@ def update_schema_in_storage(self) -> None: records, db_client=self.db_client, table_name=fq_version_table_name, - parent_table_name=None, write_disposition=write_disposition, ) @@ -845,7 +818,6 @@ def __init__( records, db_client=db_client, table_name=self.fq_table_name, - parent_table_name=self.fq_parent_table_name, write_disposition=self.write_disposition, id_field_name=self.id_field_name, ) @@ -855,3 +827,75 @@ def state(self) -> TLoadJobState: def exception(self) -> str: raise NotImplementedError() + + +class LanceDBRemoveOrphansJob(LoadJob, FollowupJob): + def __init__( + self, + db_client: DBConnection, + file_name: str, + table_schema: TTableSchema, + fq_table_name: str, + fq_parent_table_name: Optional[str], + ) -> None: + super().__init__(file_name) + self.db_client = db_client + self._state: TLoadJobState = "running" + self.table_schema: TTableSchema = table_schema + self.fq_table_name: str = fq_table_name + self.fq_parent_table_name: Optional[str] = fq_parent_table_name + self.write_disposition: TWriteDisposition = cast( + TWriteDisposition, self.table_schema.get("write_disposition", "append") + ) + + def execute(self) -> None: + if self.write_disposition not in ("merge" or "upsert"): + raise DestinationTerminalException( + f"Unsupported write disposition {self.write_disposition} for LanceDB Destination Orphan Removal Job - failed AND WILL **NOT** BE RETRIED." + ) + + try: + child_table = self.db_client.open_table(self.fq_table_name) + child_table.checkout_latest() + if self.fq_parent_table_name: + parent_table = self.db_client.open_table(self.fq_parent_table_name) + parent_table.checkout_latest() + except FileNotFoundError as e: + raise DestinationTransientException( + "Couldn't open lancedb database. Orphan removal WILL BE RETRIED" + ) from e + + try: + # Chunks in child table. + if self.fq_parent_table_name: + parent_ids = set( + pc.unique(parent_table.to_arrow()["_dlt_id"]).to_pylist() + ) + child_ids = set( + pc.unique(child_table.to_arrow()["_dlt_parent_id"]).to_pylist() + ) + + if orphaned_ids := child_ids - parent_ids: + if len(orphaned_ids) > 1: + child_table.delete( + "_dlt_parent_id IN" + f" {tuple(orphaned_ids) if len(orphaned_ids) > 1 else orphaned_ids.pop()}" + ) + elif len(orphaned_ids) == 1: + child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") + + # Chunks in root table. TODO: Add test for embeddings in root table + else: + ... + except ArrowInvalid as e: + raise DestinationTerminalException( + "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." 
+ ) from e + + self._state = "completed" + + def state(self) -> TLoadJobState: + return self._state + + def exception(self) -> str: + raise NotImplementedError() From 5b3acb167cec0fa863124191490766fdfe5c4197 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 30 Jul 2024 19:30:16 +0200 Subject: [PATCH 013/113] Implement efficient update strategy for chunked documents in LanceDB Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 57 ++++++++++++++----- 1 file changed, 44 insertions(+), 13 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 4e3d5c7488..a8ea9f6f9d 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -38,7 +38,7 @@ StorageSchemaInfo, StateInfo, TLoadJobState, - FollowupJob, + NewLoadJob, ) from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables @@ -49,7 +49,7 @@ TWriteDisposition, ) from dlt.common.schema.utils import get_columns_names_with_prop -from dlt.common.storages import FileStorage +from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import ( LanceDBClientConfiguration, @@ -70,7 +70,7 @@ generate_uuid, set_non_standard_providers_environment_variables, ) -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.job_impl import EmptyLoadJob, NewReferenceJob from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -756,6 +756,33 @@ def start_file_load( ), ) + def create_table_chain_completed_followup_jobs( + self, + table_chain: Sequence[TTableSchema], + completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, + ) -> List[NewLoadJob]: + assert completed_table_chain_jobs is not None + jobs = super().create_table_chain_completed_followup_jobs( + table_chain, completed_table_chain_jobs + ) + + for table in table_chain: + parent_table = table.get("parent") + jobs.append( + LanceDBRemoveOrphansJob( + db_client=self.db_client, + table_schema=self.prepare_load_table(table["name"]), + fq_table_name=self.make_qualified_table_name(table["name"]), + fq_parent_table_name=( + self.make_qualified_table_name(parent_table) + if parent_table + else None + ), + ) + ) + + return jobs + def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() @@ -829,18 +856,24 @@ def exception(self) -> str: raise NotImplementedError() -class LanceDBRemoveOrphansJob(LoadJob, FollowupJob): +class LanceDBRemoveOrphansJob(NewReferenceJob): def __init__( self, db_client: DBConnection, - file_name: str, table_schema: TTableSchema, fq_table_name: str, fq_parent_table_name: Optional[str], ) -> None: - super().__init__(file_name) self.db_client = db_client - self._state: TLoadJobState = "running" + + ref_file_name = ParsedLoadJobFileName( + table_schema["name"], ParsedLoadJobFileName.new_file_id(), 0, "reference" + ).file_name() + super().__init__( + file_name=ref_file_name, + status="running", + ) + self.table_schema: TTableSchema = table_schema self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name @@ -848,6 +881,8 @@ def __init__( TWriteDisposition, self.table_schema.get("write_disposition", "append") ) + self.execute() + def execute(self) -> None: if self.write_disposition not in ("merge" or "upsert"): raise DestinationTerminalException( @@ -885,6 
+920,7 @@ def execute(self) -> None: child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") # Chunks in root table. TODO: Add test for embeddings in root table + # TODO: Add unit tests with simple data else: ... except ArrowInvalid as e: @@ -892,10 +928,5 @@ def execute(self) -> None: "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." ) from e - self._state = "completed" - def state(self) -> TLoadJobState: - return self._state - - def exception(self) -> str: - raise NotImplementedError() + return "completed" From cf6d86a965d9323da1dfe7b7144cacf640ef2de0 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 30 Jul 2024 21:37:47 +0200 Subject: [PATCH 014/113] Add test for removing orphaned records in LanceDB Signed-off-by: Marcel Coetzee --- .../lancedb/test_remove_orphaned_records.py | 84 +++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 tests/load/lancedb/test_remove_orphaned_records.py diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py new file mode 100644 index 0000000000..9f5e3b288c --- /dev/null +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -0,0 +1,84 @@ +from typing import Iterator, List, Generator + +import pytest + +import dlt +from dlt.common.schema.typing import TLoaderMergeStrategy +from dlt.common.typing import DictStrAny +from dlt.common.utils import uniq_id +from tests.load.utils import ( + drop_active_pipeline_data, +) +from tests.pipeline.utils import ( + assert_load_info, + load_table_counts, + load_tables_to_dicts, +) + + +# Mark all tests as essential, don't remove. +pytestmark = pytest.mark.essential + + +@pytest.fixture(autouse=True) +def drop_lancedb_data() -> Iterator[None]: + yield + drop_active_pipeline_data() + + +@pytest.mark.parametrize("merge_strategy", ("delete-insert", "upsert")) +def test_lancedb_remove_orphaned_records( + merge_strategy: TLoaderMergeStrategy, +) -> None: + pipeline = dlt.pipeline( + pipeline_name="test_pipeline_append", + destination="lancedb", + dataset_name=f"TestPipelineAppendDataset{uniq_id()}", + ) + + @dlt.resource( + table_name="parent", + write_disposition={"disposition": "merge", "strategy": merge_strategy}, + primary_key="id", + ) + def identity_resource( + data: List[DictStrAny], + ) -> Generator[List[DictStrAny], None, None]: + yield data + + run_1 = [ + {"id": 1, "child": [{"bar": 1}, {"bar": 2}]}, + {"id": 2, "child": [{"bar": 3}]}, + {"id": 3, "child": [{"bar": 10}, {"bar": 11}]}, + ] + info = pipeline.run(identity_resource(run_1)) + assert_load_info(info) + + counts = load_table_counts(pipeline, "parent", "parent__child") + assert counts["parent"] == 2 + assert counts["parent__child"] == 4 + + run_2 = [ + {"id": 1, "child": [{"bar": 1}]}, # Removed one child. + {"id": 2, "child": [{"bar": 4}, {"baz": 1}]}, # Changed child. + ] + info = pipeline.run(identity_resource(run_2)) + assert_load_info(info) + + # Check whether orphaned child records were removed. 
+ counts = load_table_counts(pipeline, "parent", "parent__child") + assert counts["parent"] == 2 + assert counts["parent__child"] == 2 + + child_data = load_tables_to_dicts(pipeline, "parent__child") + expected_child_data = [ + {"bar": 1}, + {"bar": 4}, + {"baz": 1}, + {"bar": 10}, + {"bar": 11}, + ] + assert ( + sorted(child_data["parent__child"], key=lambda x: x["bar"]) + == expected_child_data + ) From d338586b115ef6e5fbd1a4578fe6d374b42be714 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 30 Jul 2024 22:41:48 +0200 Subject: [PATCH 015/113] Update LanceDB orphaned records removal test for chunked documents Signed-off-by: Marcel Coetzee --- .../lancedb/test_remove_orphaned_records.py | 36 +++++++------------ 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 9f5e3b288c..5e19efba09 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -11,8 +11,6 @@ ) from tests.pipeline.utils import ( assert_load_info, - load_table_counts, - load_tables_to_dicts, ) @@ -54,31 +52,23 @@ def identity_resource( info = pipeline.run(identity_resource(run_1)) assert_load_info(info) - counts = load_table_counts(pipeline, "parent", "parent__child") - assert counts["parent"] == 2 - assert counts["parent__child"] == 4 - run_2 = [ {"id": 1, "child": [{"bar": 1}]}, # Removed one child. - {"id": 2, "child": [{"bar": 4}, {"baz": 1}]}, # Changed child. + {"id": 2, "child": [{"bar": 4}]}, # Changed child. ] info = pipeline.run(identity_resource(run_2)) assert_load_info(info) - # Check whether orphaned child records were removed. - counts = load_table_counts(pipeline, "parent", "parent__child") - assert counts["parent"] == 2 - assert counts["parent__child"] == 2 + with pipeline.destination_client() as client: + expected_child_data = [ + 1, + 4, + 10, + 11, + ] - child_data = load_tables_to_dicts(pipeline, "parent__child") - expected_child_data = [ - {"bar": 1}, - {"bar": 4}, - {"baz": 1}, - {"bar": 10}, - {"bar": 11}, - ] - assert ( - sorted(child_data["parent__child"], key=lambda x: x["bar"]) - == expected_child_data - ) + embeddings_table_name = client.make_qualified_table_name("parent__child") # type: ignore[attr-defined] + + tbl = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] + df = tbl.to_pandas() + assert sorted(df["bar"].to_list()) == expected_child_data From 2376c6a445b67d8949285bacdd9e4f619e5af463 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 30 Jul 2024 22:49:22 +0200 Subject: [PATCH 016/113] Set test pipeline as dev mode Signed-off-by: Marcel Coetzee --- tests/load/lancedb/test_remove_orphaned_records.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 5e19efba09..f367d31807 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -32,6 +32,7 @@ def test_lancedb_remove_orphaned_records( pipeline_name="test_pipeline_append", destination="lancedb", dataset_name=f"TestPipelineAppendDataset{uniq_id()}", + dev_mode=True, ) @dlt.resource( From 7f6f1cd0e0b3eb9a3cd1b6fcc3d46f9a63e75c4d Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 30 Jul 2024 23:56:16 +0200 Subject: [PATCH 017/113] Fix write disposition check in LanceDBRemoveOrphansJob execute method Signed-off-by: Marcel Coetzee --- 
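The guard being fixed here relied on ("merge" or "upsert"), which is not a tuple: Python's `or` short-circuits to the first truthy operand, so the membership test silently degraded into a substring check against the single string "merge". A quick REPL illustration (standalone, not part of the diff):

    >>> ("merge" or "upsert")               # `or` returns the first truthy operand
    'merge'
    >>> "upsert" in ("merge" or "upsert")   # substring test against "merge"
    False
    >>> "erge" in ("merge" or "upsert")     # accidental substrings pass instead
    True

Both call sites are therefore rewritten below as explicit equality checks against "merge".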
dlt/destinations/impl/lancedb/lancedb_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index a8ea9f6f9d..a6dad09b60 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -186,7 +186,7 @@ def upload_batch( tbl.add(records) elif write_disposition == "replace": tbl.add(records, mode="overwrite") - elif write_disposition in ("merge" or "upsert"): + elif write_disposition == "merge": if not id_field_name: raise ValueError( "To perform a merge update, 'id_field_name' must be specified." @@ -884,7 +884,7 @@ def __init__( self.execute() def execute(self) -> None: - if self.write_disposition not in ("merge" or "upsert"): + if self.write_disposition != "merge": raise DestinationTerminalException( f"Unsupported write disposition {self.write_disposition} for LanceDB Destination Orphan Removal Job - failed AND WILL **NOT** BE RETRIED." ) From c276211b91cfc3735f0d0a36a83e0715ab4c3313 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Wed, 31 Jul 2024 16:07:14 +0200 Subject: [PATCH 018/113] Add FollowupJob trait to LoadLanceDBJob Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index a6dad09b60..5beff057dd 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -39,6 +39,7 @@ StateInfo, TLoadJobState, NewLoadJob, + FollowupJob, ) from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables @@ -70,7 +71,7 @@ generate_uuid, set_non_standard_providers_environment_variables, ) -from dlt.destinations.job_impl import EmptyLoadJob, NewReferenceJob +from dlt.destinations.job_impl import EmptyLoadJob, NewReferenceJob, NewLoadJobImpl from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -787,7 +788,7 @@ def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() -class LoadLanceDBJob(LoadJob): +class LoadLanceDBJob(LoadJob, FollowupJob): arrow_schema: TArrowSchema def __init__( @@ -856,7 +857,7 @@ def exception(self) -> str: raise NotImplementedError() -class LanceDBRemoveOrphansJob(NewReferenceJob): +class LanceDBRemoveOrphansJob(NewLoadJobImpl): def __init__( self, db_client: DBConnection, @@ -920,7 +921,6 @@ def execute(self) -> None: child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") # Chunks in root table. TODO: Add test for embeddings in root table - # TODO: Add unit tests with simple data else: ... 
except ArrowInvalid as e: From dbfd5af9c28bed805a5364d9fa954d2e8abf6210 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Wed, 31 Jul 2024 17:48:04 +0200 Subject: [PATCH 019/113] Fix file type Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 52 +++++++++++-------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 5beff057dd..cb8935e672 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -71,7 +71,7 @@ generate_uuid, set_non_standard_providers_environment_variables, ) -from dlt.destinations.job_impl import EmptyLoadJob, NewReferenceJob, NewLoadJobImpl +from dlt.destinations.job_impl import EmptyLoadJob, NewLoadJobImpl from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -768,19 +768,21 @@ def create_table_chain_completed_followup_jobs( ) for table in table_chain: - parent_table = table.get("parent") - jobs.append( - LanceDBRemoveOrphansJob( - db_client=self.db_client, - table_schema=self.prepare_load_table(table["name"]), - fq_table_name=self.make_qualified_table_name(table["name"]), - fq_parent_table_name=( - self.make_qualified_table_name(parent_table) - if parent_table - else None - ), + # Only tables with merge disposition are dispatched for orphan removal jobs. + if table.get("write_disposition", "append") == "merge": + parent_table = table.get("parent") + jobs.append( + LanceDBRemoveOrphansJob( + db_client=self.db_client, + table_schema=self.prepare_load_table(table["name"]), + fq_table_name=self.make_qualified_table_name(table["name"]), + fq_parent_table_name=( + self.make_qualified_table_name(parent_table) + if parent_table + else None + ), + ) ) - ) return jobs @@ -866,15 +868,6 @@ def __init__( fq_parent_table_name: Optional[str], ) -> None: self.db_client = db_client - - ref_file_name = ParsedLoadJobFileName( - table_schema["name"], ParsedLoadJobFileName.new_file_id(), 0, "reference" - ).file_name() - super().__init__( - file_name=ref_file_name, - status="running", - ) - self.table_schema: TTableSchema = table_schema self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name @@ -882,6 +875,20 @@ def __init__( TWriteDisposition, self.table_schema.get("write_disposition", "append") ) + job_id = ParsedLoadJobFileName( + table_schema["name"], + ParsedLoadJobFileName.new_file_id(), + 0, + "parquet", + ).file_name() + + super().__init__( + file_name=job_id, + status="running", + ) + + self._save_text_file("") + self.execute() def execute(self) -> None: @@ -890,6 +897,7 @@ def execute(self) -> None: f"Unsupported write disposition {self.write_disposition} for LanceDB Destination Orphan Removal Job - failed AND WILL **NOT** BE RETRIED." ) + # Orphans are removed irrespective of which merge strategy is picked. 
try: child_table = self.db_client.open_table(self.fq_table_name) child_table.checkout_latest() From 257fbde1c809211e4dfce4962f40807b59ad71b6 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Wed, 31 Jul 2024 17:54:55 +0200 Subject: [PATCH 020/113] Fix file typing Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 136 +++++------------- .../lancedb/test_remove_orphaned_records.py | 8 +- 2 files changed, 38 insertions(+), 106 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index cb8935e672..40da46f199 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -80,9 +80,7 @@ NDArray = ndarray TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} -UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = { - v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items() -} +UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} class LanceDBTypeMapper(TypeMapper): @@ -189,9 +187,7 @@ def upload_batch( tbl.add(records, mode="overwrite") elif write_disposition == "merge": if not id_field_name: - raise ValueError( - "To perform a merge update, 'id_field_name' must be specified." - ) + raise ValueError("To perform a merge update, 'id_field_name' must be specified.") tbl.merge_insert( id_field_name ).when_matched_update_all().when_not_matched_insert_all().execute(records) @@ -280,9 +276,7 @@ def get_table_schema(self, table_name: str) -> TArrowSchema: ) @lancedb_error - def create_table( - self, table_name: str, schema: TArrowSchema, mode: str = "create" - ) -> Table: + def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table: """Create a LanceDB Table from the provided LanceModel or PyArrow schema. Args: @@ -305,9 +299,7 @@ def delete_table(self, table_name: str) -> None: def query_table( self, table_name: str, - query: Union[ - List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None - ] = None, + query: Union[List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None] = None, ) -> LanceQueryBuilder: """Query a LanceDB table. @@ -335,11 +327,7 @@ def _get_table_names(self) -> List[str]: else: table_names = self.db_client.table_names() - return [ - table_name - for table_name in table_names - if table_name != self.sentinel_table - ] + return [table_name for table_name in table_names if table_name != self.sentinel_table] @lancedb_error def drop_storage(self) -> None: @@ -392,9 +380,7 @@ def update_stored_schema( applied_update: TSchemaTables = {} try: - schema_info = self.get_stored_schema_by_hash( - self.schema.stored_version_hash - ) + schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash) except DestinationUndefinedEntity: schema_info = None @@ -449,25 +435,19 @@ def add_table_fields( # Check if any of the new fields already exist in the table. existing_fields = set(arrow_table.schema.names) - new_fields = [ - field for field in field_schemas if field.name not in existing_fields - ] + new_fields = [field for field in field_schemas if field.name not in existing_fields] if not new_fields: # All fields already present, skip. 
return None - null_arrays = [ - pa.nulls(len(arrow_table), type=field.type) for field in new_fields - ] + null_arrays = [pa.nulls(len(arrow_table), type=field.type) for field in new_fields] for field, null_array in zip(new_fields, null_arrays): arrow_table = arrow_table.append_column(field, null_array) try: - return self.db_client.create_table( - table_name, arrow_table, mode="overwrite" - ) + return self.db_client.create_table(table_name, arrow_table, mode="overwrite") except OSError: # Error occurred while creating the table, skip. return None @@ -480,15 +460,11 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: existing_columns, self.capabilities.generates_case_sensitive_identifiers(), ) - logger.info( - f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}" - ) + logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") if len(new_columns) > 0: if exists: field_schemas: List[TArrowField] = [ - make_arrow_field_schema( - column["name"], column, self.type_mapper - ) + make_arrow_field_schema(column["name"], column, self.type_mapper) for column in new_columns ] fq_table_name = self.make_qualified_table_name(table_name) @@ -501,9 +477,7 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: vector_field_name = self.vector_field_name id_field_name = self.id_field_name embedding_model_func = self.model_func - embedding_model_dimensions = ( - self.config.embedding_model_dimensions - ) + embedding_model_dimensions = self.config.embedding_model_dimensions else: embedding_fields = None vector_field_name = None @@ -534,12 +508,8 @@ def update_schema_in_storage(self) -> None: self.schema.naming.normalize_identifier( "engine_version" ): self.schema.ENGINE_VERSION, - self.schema.naming.normalize_identifier("inserted_at"): str( - pendulum.now() - ), - self.schema.naming.normalize_identifier( - "schema_name" - ): self.schema.name, + self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), + self.schema.naming.normalize_identifier("schema_name"): self.schema.name, self.schema.naming.normalize_identifier( "version_hash" ): self.schema.stored_version_hash, @@ -548,9 +518,7 @@ def update_schema_in_storage(self) -> None: ), } ] - fq_version_table_name = self.make_qualified_table_name( - self.schema.version_table_name - ) + fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) write_disposition = self.schema.get_table(self.schema.version_table_name).get( "write_disposition" ) @@ -565,12 +533,8 @@ def update_schema_in_storage(self) -> None: @lancedb_error def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: """Retrieves the latest completed state for a pipeline.""" - fq_state_table_name = self.make_qualified_table_name( - self.schema.state_table_name - ) - fq_loads_table_name = self.make_qualified_table_name( - self.schema.loads_table_name - ) + fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name) + fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) state_table_: Table = self.db_client.open_table(fq_state_table_name) state_table_.checkout_latest() @@ -596,9 +560,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: .where(f"`{p_pipeline_name}` = '{pipeline_name}'", prefilter=True) .to_arrow() ) - loads_table = ( - loads_table_.search().where(f"`{p_status}` = 0", prefilter=True).to_arrow() - ) + loads_table = loads_table_.search().where(f"`{p_status}` = 0", 
prefilter=True).to_arrow() # Join arrow tables in-memory. joined_table: pa.Table = state_table.join( @@ -620,12 +582,8 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: ) @lancedb_error - def get_stored_schema_by_hash( - self, schema_hash: str - ) -> Optional[StorageSchemaInfo]: - fq_version_table_name = self.make_qualified_table_name( - self.schema.version_table_name - ) + def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: + fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) version_table: Table = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() @@ -643,9 +601,7 @@ def get_stored_schema_by_hash( ) ).to_list() - most_recent_schema = sorted( - schemas, key=lambda x: x[p_inserted_at], reverse=True - )[0] + most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], @@ -660,9 +616,7 @@ def get_stored_schema_by_hash( @lancedb_error def get_stored_schema(self) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage.""" - fq_version_table_name = self.make_qualified_table_name( - self.schema.version_table_name - ) + fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) version_table: Table = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() @@ -680,9 +634,7 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: ) ).to_list() - most_recent_schema = sorted( - schemas, key=lambda x: x[p_inserted_at], reverse=True - )[0] + most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], @@ -710,21 +662,15 @@ def complete_load(self, load_id: str) -> None: records = [ { self.schema.naming.normalize_identifier("load_id"): load_id, - self.schema.naming.normalize_identifier( - "schema_name" - ): self.schema.name, + self.schema.naming.normalize_identifier("schema_name"): self.schema.name, self.schema.naming.normalize_identifier("status"): 0, - self.schema.naming.normalize_identifier("inserted_at"): str( - pendulum.now() - ), + self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), self.schema.naming.normalize_identifier( "schema_version_hash" ): None, # Payload schema must match the target schema. } ] - fq_loads_table_name = self.make_qualified_table_name( - self.schema.loads_table_name - ) + fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) write_disposition = self.schema.get_table(self.schema.loads_table_name).get( "write_disposition" ) @@ -738,9 +684,7 @@ def complete_load(self, load_id: str) -> None: def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") - def start_file_load( - self, table: TTableSchema, file_path: str, load_id: str - ) -> LoadJob: + def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: parent_table = table.get("parent") return LoadLanceDBJob( @@ -769,7 +713,7 @@ def create_table_chain_completed_followup_jobs( for table in table_chain: # Only tables with merge disposition are dispatched for orphan removal jobs. 
- if table.get("write_disposition", "append") == "merge": + if table.get("write_disposition") == "merge": parent_table = table.get("parent") jobs.append( LanceDBRemoveOrphansJob( @@ -777,9 +721,7 @@ def create_table_chain_completed_followup_jobs( table_schema=self.prepare_load_table(table["name"]), fq_table_name=self.make_qualified_table_name(table["name"]), fq_parent_table_name=( - self.make_qualified_table_name(parent_table) - if parent_table - else None + self.make_qualified_table_name(parent_table) if parent_table else None ), ) ) @@ -815,9 +757,7 @@ def __init__( self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name self.unique_identifiers: Sequence[str] = list_merge_identifiers(table_schema) - self.embedding_fields: List[str] = get_columns_names_with_prop( - table_schema, VECTORIZE_HINT - ) + self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) self.embedding_model_func: TextEmbeddingFunction = model_func self.embedding_model_dimensions: int = client_config.embedding_model_dimensions self.id_field_name: str = client_config.id_field_name @@ -872,7 +812,7 @@ def __init__( self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name self.write_disposition: TWriteDisposition = cast( - TWriteDisposition, self.table_schema.get("write_disposition", "append") + TWriteDisposition, self.table_schema.get("write_disposition") ) job_id = ParsedLoadJobFileName( @@ -894,7 +834,8 @@ def __init__( def execute(self) -> None: if self.write_disposition != "merge": raise DestinationTerminalException( - f"Unsupported write disposition {self.write_disposition} for LanceDB Destination Orphan Removal Job - failed AND WILL **NOT** BE RETRIED." + f"Unsupported write disposition {self.write_disposition} for LanceDB Destination" + " Orphan Removal Job - failed AND WILL **NOT** BE RETRIED." ) # Orphans are removed irrespective of which merge strategy is picked. @@ -912,12 +853,8 @@ def execute(self) -> None: try: # Chunks in child table. if self.fq_parent_table_name: - parent_ids = set( - pc.unique(parent_table.to_arrow()["_dlt_id"]).to_pylist() - ) - child_ids = set( - pc.unique(child_table.to_arrow()["_dlt_parent_id"]).to_pylist() - ) + parent_ids = set(pc.unique(parent_table.to_arrow()["_dlt_id"]).to_pylist()) + child_ids = set(pc.unique(child_table.to_arrow()["_dlt_parent_id"]).to_pylist()) if orphaned_ids := child_ids - parent_ids: if len(orphaned_ids) > 1: @@ -929,8 +866,7 @@ def execute(self) -> None: child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") # Chunks in root table. TODO: Add test for embeddings in root table - else: - ... + else: ... except ArrowInvalid as e: raise DestinationTerminalException( "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." 
diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index f367d31807..c5b60a5f3f 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -3,7 +3,6 @@ import pytest import dlt -from dlt.common.schema.typing import TLoaderMergeStrategy from dlt.common.typing import DictStrAny from dlt.common.utils import uniq_id from tests.load.utils import ( @@ -24,10 +23,7 @@ def drop_lancedb_data() -> Iterator[None]: drop_active_pipeline_data() -@pytest.mark.parametrize("merge_strategy", ("delete-insert", "upsert")) -def test_lancedb_remove_orphaned_records( - merge_strategy: TLoaderMergeStrategy, -) -> None: +def test_lancedb_remove_orphaned_records() -> None: pipeline = dlt.pipeline( pipeline_name="test_pipeline_append", destination="lancedb", @@ -37,7 +33,7 @@ def test_lancedb_remove_orphaned_records( @dlt.resource( table_name="parent", - write_disposition={"disposition": "merge", "strategy": merge_strategy}, + write_disposition="merge", primary_key="id", ) def identity_resource( From 0502ddf01c7cf3b190f9352b1f0fbe8d2e59f13d Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Wed, 31 Jul 2024 20:25:55 +0200 Subject: [PATCH 021/113] Add test for removing orphaned records in LanceDB root table Signed-off-by: Marcel Coetzee --- .../lancedb/test_remove_orphaned_records.py | 89 ++++++++++++++++--- 1 file changed, 79 insertions(+), 10 deletions(-) diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index c5b60a5f3f..01481d83e3 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -1,6 +1,9 @@ from typing import Iterator, List, Generator +import pandas as pd import pytest +from pandas import DataFrame +from pandas.testing import assert_frame_equal import dlt from dlt.common.typing import DictStrAny @@ -25,9 +28,9 @@ def drop_lancedb_data() -> Iterator[None]: def test_lancedb_remove_orphaned_records() -> None: pipeline = dlt.pipeline( - pipeline_name="test_pipeline_append", + pipeline_name="test_lancedb_remove_orphaned_records", destination="lancedb", - dataset_name=f"TestPipelineAppendDataset{uniq_id()}", + dataset_name=f"test_lancedb_remove_orphaned_records_{uniq_id()}", dev_mode=True, ) @@ -57,15 +60,81 @@ def identity_resource( assert_load_info(info) with pipeline.destination_client() as client: - expected_child_data = [ - 1, - 4, - 10, - 11, - ] + expected_child_data = pd.DataFrame( + data=[ + {"bar": 1}, + {"bar": 4}, + {"bar": 10}, + {"bar": 11}, + ] + ) embeddings_table_name = client.make_qualified_table_name("parent__child") # type: ignore[attr-defined] tbl = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] - df = tbl.to_pandas() - assert sorted(df["bar"].to_list()) == expected_child_data + actual_df = tbl.to_pandas() + + expected_child_data = expected_child_data.sort_values(by="bar") + actual_df = actual_df.sort_values(by="bar").reset_index(drop=True) + + assert_frame_equal(actual_df[["bar"]], expected_child_data) + + +def test_lancedb_remove_orphaned_records_root_table() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_remove_orphaned_records_root_table", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_root_table_{uniq_id()}", + dev_mode=True, + ) + + @dlt.resource( + table_name="root", + write_disposition="merge", + primary_key="doc_id", + ) + def 
identity_resource( + data: List[DictStrAny], + ) -> Generator[List[DictStrAny], None, None]: + yield data + + run_1 = [ + {"doc_id": 1, "chunk_hash": "1a"}, + {"doc_id": 2, "chunk_hash": "2a"}, + {"doc_id": 2, "chunk_hash": "2b"}, + {"doc_id": 2, "chunk_hash": "2c"}, + {"doc_id": 3, "chunk_hash": "3a"}, + {"doc_id": 3, "chunk_hash": "3b"}, + ] + info = pipeline.run(identity_resource(run_1)) + assert_load_info(info) + + run_2 = [ + {"doc_id": 2, "chunk_hash": "2d"}, + {"doc_id": 2, "chunk_hash": "2e"}, + {"doc_id": 3, "chunk_hash": "3b"}, + ] + info = pipeline.run(identity_resource(run_2)) + assert_load_info(info) + + with pipeline.destination_client() as client: + expected_root_table_df = pd.DataFrame( + data=[ + {"doc_id": 1, "chunk_hash": "1a"}, + {"doc_id": 2, "chunk_hash": "2d"}, + {"doc_id": 2, "chunk_hash": "2e"}, + {"doc_id": 3, "chunk_hash": "3b"}, + ] + ).sort_values(by=["doc_id", "chunk_hash"]) + + root_table_name = ( + client.make_qualified_table_name("root") # type: ignore[attr-defined] + ) + tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] + + actual_root_df: DataFrame = ( + tbl.to_pandas() + .sort_values(by=["doc_id", "chunk_hash"]) + .reset_index(drop=True) + ) + assert_frame_equal(actual_root_df, expected_root_table_df) From 2363b51a38c085f6e9f81b1002a2e79021cb0724 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Wed, 31 Jul 2024 20:48:16 +0200 Subject: [PATCH 022/113] Enhance LanceDB test to cover nested child removal and update scenarios Signed-off-by: Marcel Coetzee --- .../lancedb/test_remove_orphaned_records.py | 93 +++++++++++++++---- 1 file changed, 74 insertions(+), 19 deletions(-) diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 01481d83e3..9104399c9b 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -45,16 +45,34 @@ def identity_resource( yield data run_1 = [ - {"id": 1, "child": [{"bar": 1}, {"bar": 2}]}, - {"id": 2, "child": [{"bar": 3}]}, - {"id": 3, "child": [{"bar": 10}, {"bar": 11}]}, + { + "id": 1, + "child": [ + {"bar": 1, "grandchild": [{"baz": 1}, {"baz": 2}]}, + {"bar": 2, "grandchild": [{"baz": 3}]}, + ], + }, + {"id": 2, "child": [{"bar": 3, "grandchild": [{"baz": 4}]}]}, + { + "id": 3, + "child": [ + {"bar": 10, "grandchild": [{"baz": 5}]}, + {"bar": 11, "grandchild": [{"baz": 6}, {"baz": 7}]}, + ], + }, ] info = pipeline.run(identity_resource(run_1)) assert_load_info(info) run_2 = [ - {"id": 1, "child": [{"bar": 1}]}, # Removed one child. - {"id": 2, "child": [{"bar": 4}]}, # Changed child. 
+ { + "id": 1, + "child": [{"bar": 1, "grandchild": [{"baz": 1}]}], + }, # Removed one child and one grandchild + { + "id": 2, + "child": [{"bar": 4, "grandchild": [{"baz": 8}]}], + }, # Changed child and grandchild ] info = pipeline.run(identity_resource(run_2)) assert_load_info(info) @@ -69,15 +87,48 @@ def identity_resource( ] ) - embeddings_table_name = client.make_qualified_table_name("parent__child") # type: ignore[attr-defined] + expected_grandchild_data = pd.DataFrame( + data=[ + {"baz": 1}, + {"baz": 8}, + {"baz": 5}, + {"baz": 6}, + {"baz": 7}, + ] + ) + + child_table_name = client.make_qualified_table_name("parent__child") # type: ignore[attr-defined] + grandchild_table_name = client.make_qualified_table_name( # type: ignore[attr-defined] + "parent__child__grandchild" + ) - tbl = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] - actual_df = tbl.to_pandas() + child_tbl = client.db_client.open_table(child_table_name) # type: ignore[attr-defined] + grandchild_tbl = client.db_client.open_table(grandchild_table_name) # type: ignore[attr-defined] + + actual_child_df = ( + child_tbl.to_pandas() + .sort_values(by="bar") + .reset_index(drop=True) + .reset_index(drop=True) + ) + actual_grandchild_df = ( + grandchild_tbl.to_pandas() + .sort_values(by="baz") + .reset_index(drop=True) + .reset_index(drop=True) + ) - expected_child_data = expected_child_data.sort_values(by="bar") - actual_df = actual_df.sort_values(by="bar").reset_index(drop=True) + expected_child_data = expected_child_data.sort_values(by="bar").reset_index( + drop=True + ) + expected_grandchild_data = expected_grandchild_data.sort_values( + by="baz" + ).reset_index(drop=True) - assert_frame_equal(actual_df[["bar"]], expected_child_data) + assert_frame_equal(actual_child_df[["bar"]], expected_child_data) + print(actual_grandchild_df[["baz"]]) + print(expected_grandchild_data) + assert_frame_equal(actual_grandchild_df[["baz"]], expected_grandchild_data) def test_lancedb_remove_orphaned_records_root_table() -> None: @@ -118,14 +169,18 @@ def identity_resource( assert_load_info(info) with pipeline.destination_client() as client: - expected_root_table_df = pd.DataFrame( - data=[ - {"doc_id": 1, "chunk_hash": "1a"}, - {"doc_id": 2, "chunk_hash": "2d"}, - {"doc_id": 2, "chunk_hash": "2e"}, - {"doc_id": 3, "chunk_hash": "3b"}, - ] - ).sort_values(by=["doc_id", "chunk_hash"]) + expected_root_table_df = ( + pd.DataFrame( + data=[ + {"doc_id": 1, "chunk_hash": "1a"}, + {"doc_id": 2, "chunk_hash": "2d"}, + {"doc_id": 2, "chunk_hash": "2e"}, + {"doc_id": 3, "chunk_hash": "3b"}, + ] + ) + .sort_values(by=["doc_id", "chunk_hash"]) + .reset_index(drop=True) + ) root_table_name = ( client.make_qualified_table_name("root") # type: ignore[attr-defined] From 6b363d1b6601d795e71a8a29b8501c4958cf3d69 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 1 Aug 2024 22:20:53 +0200 Subject: [PATCH 023/113] Use doc id hint for top level tables Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 178 ++++++++++++++---- dlt/destinations/impl/lancedb/utils.py | 4 +- .../lancedb/test_remove_orphaned_records.py | 15 +- 3 files changed, 149 insertions(+), 48 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 40da46f199..d2cc0982fb 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -12,6 +12,7 @@ Dict, Sequence, TYPE_CHECKING, + Set, ) import lancedb # type: 
ignore @@ -22,7 +23,7 @@ from lancedb.query import LanceQueryBuilder # type: ignore from lancedb.table import Table # type: ignore from numpy import ndarray -from pyarrow import Array, ChunkedArray, ArrowInvalid +from pyarrow import Array, ChunkedArray, ArrowInvalid, Table from dlt.common import json, pendulum, logger from dlt.common.destination import DestinationCapabilitiesContext @@ -58,7 +59,10 @@ from dlt.destinations.impl.lancedb.exceptions import ( lancedb_error, ) -from dlt.destinations.impl.lancedb.lancedb_adapter import VECTORIZE_HINT +from dlt.destinations.impl.lancedb.lancedb_adapter import ( + VECTORIZE_HINT, + DOCUMENT_ID_HINT, +) from dlt.destinations.impl.lancedb.schema import ( make_arrow_field_schema, make_arrow_table_schema, @@ -80,7 +84,9 @@ NDArray = ndarray TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} -UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} +UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = { + v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items() +} class LanceDBTypeMapper(TypeMapper): @@ -187,7 +193,9 @@ def upload_batch( tbl.add(records, mode="overwrite") elif write_disposition == "merge": if not id_field_name: - raise ValueError("To perform a merge update, 'id_field_name' must be specified.") + raise ValueError( + "To perform a merge update, 'id_field_name' must be specified." + ) tbl.merge_insert( id_field_name ).when_matched_update_all().when_not_matched_insert_all().execute(records) @@ -276,7 +284,9 @@ def get_table_schema(self, table_name: str) -> TArrowSchema: ) @lancedb_error - def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table: + def create_table( + self, table_name: str, schema: TArrowSchema, mode: str = "create" + ) -> Table: """Create a LanceDB Table from the provided LanceModel or PyArrow schema. Args: @@ -299,7 +309,9 @@ def delete_table(self, table_name: str) -> None: def query_table( self, table_name: str, - query: Union[List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None] = None, + query: Union[ + List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None + ] = None, ) -> LanceQueryBuilder: """Query a LanceDB table. @@ -327,7 +339,11 @@ def _get_table_names(self) -> List[str]: else: table_names = self.db_client.table_names() - return [table_name for table_name in table_names if table_name != self.sentinel_table] + return [ + table_name + for table_name in table_names + if table_name != self.sentinel_table + ] @lancedb_error def drop_storage(self) -> None: @@ -380,7 +396,9 @@ def update_stored_schema( applied_update: TSchemaTables = {} try: - schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash) + schema_info = self.get_stored_schema_by_hash( + self.schema.stored_version_hash + ) except DestinationUndefinedEntity: schema_info = None @@ -435,19 +453,25 @@ def add_table_fields( # Check if any of the new fields already exist in the table. existing_fields = set(arrow_table.schema.names) - new_fields = [field for field in field_schemas if field.name not in existing_fields] + new_fields = [ + field for field in field_schemas if field.name not in existing_fields + ] if not new_fields: # All fields already present, skip. 
return None - null_arrays = [pa.nulls(len(arrow_table), type=field.type) for field in new_fields] + null_arrays = [ + pa.nulls(len(arrow_table), type=field.type) for field in new_fields + ] for field, null_array in zip(new_fields, null_arrays): arrow_table = arrow_table.append_column(field, null_array) try: - return self.db_client.create_table(table_name, arrow_table, mode="overwrite") + return self.db_client.create_table( + table_name, arrow_table, mode="overwrite" + ) except OSError: # Error occurred while creating the table, skip. return None @@ -460,11 +484,15 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: existing_columns, self.capabilities.generates_case_sensitive_identifiers(), ) - logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") + logger.info( + f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}" + ) if len(new_columns) > 0: if exists: field_schemas: List[TArrowField] = [ - make_arrow_field_schema(column["name"], column, self.type_mapper) + make_arrow_field_schema( + column["name"], column, self.type_mapper + ) for column in new_columns ] fq_table_name = self.make_qualified_table_name(table_name) @@ -477,7 +505,9 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: vector_field_name = self.vector_field_name id_field_name = self.id_field_name embedding_model_func = self.model_func - embedding_model_dimensions = self.config.embedding_model_dimensions + embedding_model_dimensions = ( + self.config.embedding_model_dimensions + ) else: embedding_fields = None vector_field_name = None @@ -508,8 +538,12 @@ def update_schema_in_storage(self) -> None: self.schema.naming.normalize_identifier( "engine_version" ): self.schema.ENGINE_VERSION, - self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), - self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier("inserted_at"): str( + pendulum.now() + ), + self.schema.naming.normalize_identifier( + "schema_name" + ): self.schema.name, self.schema.naming.normalize_identifier( "version_hash" ): self.schema.stored_version_hash, @@ -518,7 +552,9 @@ def update_schema_in_storage(self) -> None: ), } ] - fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) + fq_version_table_name = self.make_qualified_table_name( + self.schema.version_table_name + ) write_disposition = self.schema.get_table(self.schema.version_table_name).get( "write_disposition" ) @@ -533,8 +569,12 @@ def update_schema_in_storage(self) -> None: @lancedb_error def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: """Retrieves the latest completed state for a pipeline.""" - fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name) - fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) + fq_state_table_name = self.make_qualified_table_name( + self.schema.state_table_name + ) + fq_loads_table_name = self.make_qualified_table_name( + self.schema.loads_table_name + ) state_table_: Table = self.db_client.open_table(fq_state_table_name) state_table_.checkout_latest() @@ -560,7 +600,9 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: .where(f"`{p_pipeline_name}` = '{pipeline_name}'", prefilter=True) .to_arrow() ) - loads_table = loads_table_.search().where(f"`{p_status}` = 0", prefilter=True).to_arrow() + loads_table = ( + loads_table_.search().where(f"`{p_status}` = 0", 
prefilter=True).to_arrow() + ) # Join arrow tables in-memory. joined_table: pa.Table = state_table.join( @@ -582,8 +624,12 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: ) @lancedb_error - def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: - fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) + def get_stored_schema_by_hash( + self, schema_hash: str + ) -> Optional[StorageSchemaInfo]: + fq_version_table_name = self.make_qualified_table_name( + self.schema.version_table_name + ) version_table: Table = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() @@ -601,7 +647,9 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI ) ).to_list() - most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] + most_recent_schema = sorted( + schemas, key=lambda x: x[p_inserted_at], reverse=True + )[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], @@ -616,7 +664,9 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI @lancedb_error def get_stored_schema(self) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage.""" - fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) + fq_version_table_name = self.make_qualified_table_name( + self.schema.version_table_name + ) version_table: Table = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() @@ -634,7 +684,9 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: ) ).to_list() - most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] + most_recent_schema = sorted( + schemas, key=lambda x: x[p_inserted_at], reverse=True + )[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], @@ -662,15 +714,21 @@ def complete_load(self, load_id: str) -> None: records = [ { self.schema.naming.normalize_identifier("load_id"): load_id, - self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier( + "schema_name" + ): self.schema.name, self.schema.naming.normalize_identifier("status"): 0, - self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), + self.schema.naming.normalize_identifier("inserted_at"): str( + pendulum.now() + ), self.schema.naming.normalize_identifier( "schema_version_hash" ): None, # Payload schema must match the target schema. 
} ] - fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) + fq_loads_table_name = self.make_qualified_table_name( + self.schema.loads_table_name + ) write_disposition = self.schema.get_table(self.schema.loads_table_name).get( "write_disposition" ) @@ -684,7 +742,9 @@ def complete_load(self, load_id: str) -> None: def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: parent_table = table.get("parent") return LoadLanceDBJob( @@ -712,6 +772,9 @@ def create_table_chain_completed_followup_jobs( ) for table in table_chain: + if table in self.schema.dlt_tables(): + continue + # Only tables with merge disposition are dispatched for orphan removal jobs. if table.get("write_disposition") == "merge": parent_table = table.get("parent") @@ -721,8 +784,11 @@ def create_table_chain_completed_followup_jobs( table_schema=self.prepare_load_table(table["name"]), fq_table_name=self.make_qualified_table_name(table["name"]), fq_parent_table_name=( - self.make_qualified_table_name(parent_table) if parent_table else None + self.make_qualified_table_name(parent_table) + if parent_table + else None ), + client_config=self.config, ) ) @@ -756,8 +822,10 @@ def __init__( self.table_name: str = table_schema["name"] self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name - self.unique_identifiers: Sequence[str] = list_merge_identifiers(table_schema) - self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) + self.unique_identifiers: List[str] = list_merge_identifiers(table_schema) + self.embedding_fields: List[str] = get_columns_names_with_prop( + table_schema, VECTORIZE_HINT + ) self.embedding_model_func: TextEmbeddingFunction = model_func self.embedding_model_dimensions: int = client_config.embedding_model_dimensions self.id_field_name: str = client_config.id_field_name @@ -804,6 +872,7 @@ def __init__( self, db_client: DBConnection, table_schema: TTableSchema, + client_config: LanceDBClientConfiguration, fq_table_name: str, fq_parent_table_name: Optional[str], ) -> None: @@ -814,6 +883,7 @@ def __init__( self.write_disposition: TWriteDisposition = cast( TWriteDisposition, self.table_schema.get("write_disposition") ) + self.id_field_name: str = client_config.id_field_name job_id = ParsedLoadJobFileName( table_schema["name"], @@ -832,6 +902,8 @@ def __init__( self.execute() def execute(self) -> None: + orphaned_ids: Set[str] + if self.write_disposition != "merge": raise DestinationTerminalException( f"Unsupported write disposition {self.write_disposition} for LanceDB Destination" @@ -851,22 +923,50 @@ def execute(self) -> None: ) from e try: - # Chunks in child table. if self.fq_parent_table_name: - parent_ids = set(pc.unique(parent_table.to_arrow()["_dlt_id"]).to_pylist()) - child_ids = set(pc.unique(child_table.to_arrow()["_dlt_parent_id"]).to_pylist()) + # Chunks and embeddings in child table. 
+ parent_ids = set( + pc.unique(parent_table.to_arrow()["_dlt_id"]).to_pylist() + ) + child_ids = set( + pc.unique(child_table.to_arrow()["_dlt_parent_id"]).to_pylist() + ) if orphaned_ids := child_ids - parent_ids: if len(orphaned_ids) > 1: child_table.delete( - "_dlt_parent_id IN" - f" {tuple(orphaned_ids) if len(orphaned_ids) > 1 else orphaned_ids.pop()}" + "_dlt_parent_id IN" f" {tuple(orphaned_ids)}" ) elif len(orphaned_ids) == 1: child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") - # Chunks in root table. TODO: Add test for embeddings in root table - else: ... + else: + # Chunks and embeddings in the root table. + child_table_arrow: pa.Table = child_table.to_arrow() + + # If document ID is defined, we use this as the sole grouping key to identify stale chunks, + # else fallback to the compound `id_field_name`. + grouping_key = ( + get_columns_names_with_prop(self.table_schema, DOCUMENT_ID_HINT) + or self.id_field_name + ) + + grouped = child_table_arrow.group_by(grouping_key).aggregate( + [("_dlt_load_id", "max")] + ) + joined = child_table_arrow.join(grouped, keys=grouping_key) + orphaned_mask = pc.not_equal( + joined["_dlt_load_id"], joined["_dlt_load_id_max"] + ) + orphaned_ids = ( + joined.filter(orphaned_mask).column("_dlt_id").to_pylist() + ) + + if len(orphaned_ids) > 1: + child_table.delete("_dlt_id IN" f" {tuple(orphaned_ids)}") + elif len(orphaned_ids) == 1: + child_table.delete(f"_dlt_id = '{orphaned_ids.pop()}'") + except ArrowInvalid as e: raise DestinationTerminalException( "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 874852512a..1d944ff7a4 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -1,6 +1,6 @@ import os import uuid -from typing import Sequence, Union, Dict +from typing import Sequence, Union, Dict, List from dlt.common.schema import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop @@ -32,7 +32,7 @@ def generate_uuid(data: DictStrAny, unique_identifiers: Sequence[str], table_nam return str(uuid.uuid5(uuid.NAMESPACE_DNS, table_name + data_id)) -def list_merge_identifiers(table_schema: TTableSchema) -> Sequence[str]: +def list_merge_identifiers(table_schema: TTableSchema) -> List[str]: """Returns a list of merge keys for a table used for either merging or deduplication. 
Args: diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 9104399c9b..4679238010 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -8,6 +8,7 @@ import dlt from dlt.common.typing import DictStrAny from dlt.common.utils import uniq_id +from dlt.destinations.impl.lancedb.lancedb_adapter import DOCUMENT_ID_HINT from tests.load.utils import ( drop_active_pipeline_data, ) @@ -34,10 +35,10 @@ def test_lancedb_remove_orphaned_records() -> None: dev_mode=True, ) - @dlt.resource( + @dlt.resource( # type: ignore[call-overload] table_name="parent", write_disposition="merge", - primary_key="id", + columns={"id": {DOCUMENT_ID_HINT: True}}, ) def identity_resource( data: List[DictStrAny], @@ -126,8 +127,6 @@ def identity_resource( ).reset_index(drop=True) assert_frame_equal(actual_child_df[["bar"]], expected_child_data) - print(actual_grandchild_df[["baz"]]) - print(expected_grandchild_data) assert_frame_equal(actual_grandchild_df[["baz"]], expected_grandchild_data) @@ -139,10 +138,11 @@ def test_lancedb_remove_orphaned_records_root_table() -> None: dev_mode=True, ) - @dlt.resource( + @dlt.resource( # type: ignore[call-overload] table_name="root", write_disposition="merge", - primary_key="doc_id", + merge_key=["chunk_hash"], + columns={"doc_id": {DOCUMENT_ID_HINT: True}}, ) def identity_resource( data: List[DictStrAny], @@ -191,5 +191,6 @@ def identity_resource( tbl.to_pandas() .sort_values(by=["doc_id", "chunk_hash"]) .reset_index(drop=True) - ) + )[["doc_id", "chunk_hash"]] + assert_frame_equal(actual_root_df, expected_root_table_df) From aac7647ed7ac28cfd67f768ce044d34d452dd181 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 1 Aug 2024 23:03:55 +0200 Subject: [PATCH 024/113] Only join on join columns for orphan removal job Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 145 +++++------------- dlt/destinations/impl/lancedb/utils.py | 4 +- .../lancedb/test_remove_orphaned_records.py | 14 +- 3 files changed, 44 insertions(+), 119 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index d2cc0982fb..cc35aadf18 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -84,9 +84,7 @@ NDArray = ndarray TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} -UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = { - v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items() -} +UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} class LanceDBTypeMapper(TypeMapper): @@ -193,9 +191,7 @@ def upload_batch( tbl.add(records, mode="overwrite") elif write_disposition == "merge": if not id_field_name: - raise ValueError( - "To perform a merge update, 'id_field_name' must be specified." - ) + raise ValueError("To perform a merge update, 'id_field_name' must be specified.") tbl.merge_insert( id_field_name ).when_matched_update_all().when_not_matched_insert_all().execute(records) @@ -284,9 +280,7 @@ def get_table_schema(self, table_name: str) -> TArrowSchema: ) @lancedb_error - def create_table( - self, table_name: str, schema: TArrowSchema, mode: str = "create" - ) -> Table: + def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table: """Create a LanceDB Table from the provided LanceModel or PyArrow schema. 
Args: @@ -309,9 +303,7 @@ def delete_table(self, table_name: str) -> None: def query_table( self, table_name: str, - query: Union[ - List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None - ] = None, + query: Union[List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None] = None, ) -> LanceQueryBuilder: """Query a LanceDB table. @@ -339,11 +331,7 @@ def _get_table_names(self) -> List[str]: else: table_names = self.db_client.table_names() - return [ - table_name - for table_name in table_names - if table_name != self.sentinel_table - ] + return [table_name for table_name in table_names if table_name != self.sentinel_table] @lancedb_error def drop_storage(self) -> None: @@ -396,9 +384,7 @@ def update_stored_schema( applied_update: TSchemaTables = {} try: - schema_info = self.get_stored_schema_by_hash( - self.schema.stored_version_hash - ) + schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash) except DestinationUndefinedEntity: schema_info = None @@ -453,25 +439,19 @@ def add_table_fields( # Check if any of the new fields already exist in the table. existing_fields = set(arrow_table.schema.names) - new_fields = [ - field for field in field_schemas if field.name not in existing_fields - ] + new_fields = [field for field in field_schemas if field.name not in existing_fields] if not new_fields: # All fields already present, skip. return None - null_arrays = [ - pa.nulls(len(arrow_table), type=field.type) for field in new_fields - ] + null_arrays = [pa.nulls(len(arrow_table), type=field.type) for field in new_fields] for field, null_array in zip(new_fields, null_arrays): arrow_table = arrow_table.append_column(field, null_array) try: - return self.db_client.create_table( - table_name, arrow_table, mode="overwrite" - ) + return self.db_client.create_table(table_name, arrow_table, mode="overwrite") except OSError: # Error occurred while creating the table, skip. 
return None @@ -484,15 +464,11 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: existing_columns, self.capabilities.generates_case_sensitive_identifiers(), ) - logger.info( - f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}" - ) + logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") if len(new_columns) > 0: if exists: field_schemas: List[TArrowField] = [ - make_arrow_field_schema( - column["name"], column, self.type_mapper - ) + make_arrow_field_schema(column["name"], column, self.type_mapper) for column in new_columns ] fq_table_name = self.make_qualified_table_name(table_name) @@ -505,9 +481,7 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: vector_field_name = self.vector_field_name id_field_name = self.id_field_name embedding_model_func = self.model_func - embedding_model_dimensions = ( - self.config.embedding_model_dimensions - ) + embedding_model_dimensions = self.config.embedding_model_dimensions else: embedding_fields = None vector_field_name = None @@ -538,12 +512,8 @@ def update_schema_in_storage(self) -> None: self.schema.naming.normalize_identifier( "engine_version" ): self.schema.ENGINE_VERSION, - self.schema.naming.normalize_identifier("inserted_at"): str( - pendulum.now() - ), - self.schema.naming.normalize_identifier( - "schema_name" - ): self.schema.name, + self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), + self.schema.naming.normalize_identifier("schema_name"): self.schema.name, self.schema.naming.normalize_identifier( "version_hash" ): self.schema.stored_version_hash, @@ -552,9 +522,7 @@ def update_schema_in_storage(self) -> None: ), } ] - fq_version_table_name = self.make_qualified_table_name( - self.schema.version_table_name - ) + fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) write_disposition = self.schema.get_table(self.schema.version_table_name).get( "write_disposition" ) @@ -569,12 +537,8 @@ def update_schema_in_storage(self) -> None: @lancedb_error def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: """Retrieves the latest completed state for a pipeline.""" - fq_state_table_name = self.make_qualified_table_name( - self.schema.state_table_name - ) - fq_loads_table_name = self.make_qualified_table_name( - self.schema.loads_table_name - ) + fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name) + fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) state_table_: Table = self.db_client.open_table(fq_state_table_name) state_table_.checkout_latest() @@ -600,9 +564,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: .where(f"`{p_pipeline_name}` = '{pipeline_name}'", prefilter=True) .to_arrow() ) - loads_table = ( - loads_table_.search().where(f"`{p_status}` = 0", prefilter=True).to_arrow() - ) + loads_table = loads_table_.search().where(f"`{p_status}` = 0", prefilter=True).to_arrow() # Join arrow tables in-memory. 
joined_table: pa.Table = state_table.join( @@ -624,12 +586,8 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: ) @lancedb_error - def get_stored_schema_by_hash( - self, schema_hash: str - ) -> Optional[StorageSchemaInfo]: - fq_version_table_name = self.make_qualified_table_name( - self.schema.version_table_name - ) + def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: + fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) version_table: Table = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() @@ -647,9 +605,7 @@ def get_stored_schema_by_hash( ) ).to_list() - most_recent_schema = sorted( - schemas, key=lambda x: x[p_inserted_at], reverse=True - )[0] + most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], @@ -664,9 +620,7 @@ def get_stored_schema_by_hash( @lancedb_error def get_stored_schema(self) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage.""" - fq_version_table_name = self.make_qualified_table_name( - self.schema.version_table_name - ) + fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) version_table: Table = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() @@ -684,9 +638,7 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: ) ).to_list() - most_recent_schema = sorted( - schemas, key=lambda x: x[p_inserted_at], reverse=True - )[0] + most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], @@ -714,21 +666,15 @@ def complete_load(self, load_id: str) -> None: records = [ { self.schema.naming.normalize_identifier("load_id"): load_id, - self.schema.naming.normalize_identifier( - "schema_name" - ): self.schema.name, + self.schema.naming.normalize_identifier("schema_name"): self.schema.name, self.schema.naming.normalize_identifier("status"): 0, - self.schema.naming.normalize_identifier("inserted_at"): str( - pendulum.now() - ), + self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), self.schema.naming.normalize_identifier( "schema_version_hash" ): None, # Payload schema must match the target schema. 
} ] - fq_loads_table_name = self.make_qualified_table_name( - self.schema.loads_table_name - ) + fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) write_disposition = self.schema.get_table(self.schema.loads_table_name).get( "write_disposition" ) @@ -742,9 +688,7 @@ def complete_load(self, load_id: str) -> None: def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") - def start_file_load( - self, table: TTableSchema, file_path: str, load_id: str - ) -> LoadJob: + def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: parent_table = table.get("parent") return LoadLanceDBJob( @@ -784,9 +728,7 @@ def create_table_chain_completed_followup_jobs( table_schema=self.prepare_load_table(table["name"]), fq_table_name=self.make_qualified_table_name(table["name"]), fq_parent_table_name=( - self.make_qualified_table_name(parent_table) - if parent_table - else None + self.make_qualified_table_name(parent_table) if parent_table else None ), client_config=self.config, ) @@ -823,9 +765,7 @@ def __init__( self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name self.unique_identifiers: List[str] = list_merge_identifiers(table_schema) - self.embedding_fields: List[str] = get_columns_names_with_prop( - table_schema, VECTORIZE_HINT - ) + self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) self.embedding_model_func: TextEmbeddingFunction = model_func self.embedding_model_dimensions: int = client_config.embedding_model_dimensions self.id_field_name: str = client_config.id_field_name @@ -925,24 +865,17 @@ def execute(self) -> None: try: if self.fq_parent_table_name: # Chunks and embeddings in child table. - parent_ids = set( - pc.unique(parent_table.to_arrow()["_dlt_id"]).to_pylist() - ) - child_ids = set( - pc.unique(child_table.to_arrow()["_dlt_parent_id"]).to_pylist() - ) + parent_ids = set(pc.unique(parent_table.to_arrow()["_dlt_id"]).to_pylist()) + child_ids = set(pc.unique(child_table.to_arrow()["_dlt_parent_id"]).to_pylist()) if orphaned_ids := child_ids - parent_ids: if len(orphaned_ids) > 1: - child_table.delete( - "_dlt_parent_id IN" f" {tuple(orphaned_ids)}" - ) + child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") elif len(orphaned_ids) == 1: child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") else: # Chunks and embeddings in the root table. - child_table_arrow: pa.Table = child_table.to_arrow() # If document ID is defined, we use this as the sole grouping key to identify stale chunks, # else fallback to the compound `id_field_name`. 
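# The next hunk reworks this stale-chunk detection. As a standalone illustration
# with assumed column values, the idea is: group chunks by their document key,
# keep only rows from the latest load per document, and collect the _dlt_id of
# everything older so it can be deleted.
import pyarrow as pa
import pyarrow.compute as pc

_chunks = pa.table(
    {"doc_id": [1, 1, 2], "_dlt_load_id": ["10", "20", "10"], "_dlt_id": ["a", "b", "c"]}
)
_latest = _chunks.group_by("doc_id").aggregate([("_dlt_load_id", "max")])
_joined = _chunks.join(_latest, keys="doc_id")
_stale_mask = pc.not_equal(_joined["_dlt_load_id"], _joined["_dlt_load_id_max"])
_orphaned_ids = _joined.filter(_stale_mask).column("_dlt_id").to_pylist()  # ["a"]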
@@ -950,20 +883,20 @@ def execute(self) -> None: get_columns_names_with_prop(self.table_schema, DOCUMENT_ID_HINT) or self.id_field_name ) + grouping_key = grouping_key if isinstance(grouping_key, list) else [grouping_key] + child_table_arrow: pa.Table = child_table.to_arrow().select( + [*grouping_key, "_dlt_load_id", "_dlt_id"] + ) grouped = child_table_arrow.group_by(grouping_key).aggregate( [("_dlt_load_id", "max")] ) joined = child_table_arrow.join(grouped, keys=grouping_key) - orphaned_mask = pc.not_equal( - joined["_dlt_load_id"], joined["_dlt_load_id_max"] - ) - orphaned_ids = ( - joined.filter(orphaned_mask).column("_dlt_id").to_pylist() - ) + orphaned_mask = pc.not_equal(joined["_dlt_load_id"], joined["_dlt_load_id_max"]) + orphaned_ids = joined.filter(orphaned_mask).column("_dlt_id").to_pylist() if len(orphaned_ids) > 1: - child_table.delete("_dlt_id IN" f" {tuple(orphaned_ids)}") + child_table.delete(f"_dlt_id IN {tuple(orphaned_ids)}") elif len(orphaned_ids) == 1: child_table.delete(f"_dlt_id = '{orphaned_ids.pop()}'") diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 1d944ff7a4..f202903598 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -6,7 +6,6 @@ from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider -from dlt.destinations.impl.lancedb.lancedb_adapter import DOCUMENT_ID_HINT PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = { @@ -42,10 +41,9 @@ def list_merge_identifiers(table_schema: TTableSchema) -> List[str]: Sequence[str]: A list of unique column identifiers. """ if table_schema.get("write_disposition") == "merge": - document_id = get_columns_names_with_prop(table_schema, DOCUMENT_ID_HINT) primary_keys = get_columns_names_with_prop(table_schema, "primary_key") merge_keys = get_columns_names_with_prop(table_schema, "merge_key") - if join_keys := list(set(primary_keys + merge_keys + document_id)): + if join_keys := list(set(primary_keys + merge_keys)): return join_keys return get_columns_names_with_prop(table_schema, "unique") diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 4679238010..285ecec577 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -119,12 +119,10 @@ def identity_resource( .reset_index(drop=True) ) - expected_child_data = expected_child_data.sort_values(by="bar").reset_index( + expected_child_data = expected_child_data.sort_values(by="bar").reset_index(drop=True) + expected_grandchild_data = expected_grandchild_data.sort_values(by="baz").reset_index( drop=True ) - expected_grandchild_data = expected_grandchild_data.sort_values( - by="baz" - ).reset_index(drop=True) assert_frame_equal(actual_child_df[["bar"]], expected_child_data) assert_frame_equal(actual_grandchild_df[["baz"]], expected_grandchild_data) @@ -182,15 +180,11 @@ def identity_resource( .reset_index(drop=True) ) - root_table_name = ( - client.make_qualified_table_name("root") # type: ignore[attr-defined] - ) + root_table_name = client.make_qualified_table_name("root") # type: ignore[attr-defined] tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] actual_root_df: DataFrame = ( - tbl.to_pandas() - .sort_values(by=["doc_id", "chunk_hash"]) - .reset_index(drop=True) + 
tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash"]).reset_index(drop=True) )[["doc_id", "chunk_hash"]] assert_frame_equal(actual_root_df, expected_root_table_df) From e33b7cffe95e444ced4300eeec84add4a2d92bad Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 1 Aug 2024 23:51:15 +0200 Subject: [PATCH 025/113] Add ollama to supported embedding providers and test orphaned record removal with embeddings Signed-off-by: Marcel Coetzee --- .../impl/lancedb/configuration.py | 1 + .../lancedb/test_remove_orphaned_records.py | 95 ++++++++++++++++++- 2 files changed, 91 insertions(+), 5 deletions(-) diff --git a/dlt/destinations/impl/lancedb/configuration.py b/dlt/destinations/impl/lancedb/configuration.py index ba3a8b49d9..5aa4ba714f 100644 --- a/dlt/destinations/impl/lancedb/configuration.py +++ b/dlt/destinations/impl/lancedb/configuration.py @@ -59,6 +59,7 @@ class LanceDBClientOptions(BaseConfiguration): "sentence-transformers", "huggingface", "colbert", + "ollama", ] diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 285ecec577..2c25a315d4 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -1,14 +1,21 @@ -from typing import Iterator, List, Generator +from typing import Iterator, List, Generator, Any +import numpy as np import pandas as pd import pytest +from lancedb.table import Table # type: ignore from pandas import DataFrame from pandas.testing import assert_frame_equal +from pyarrow import Table import dlt from dlt.common.typing import DictStrAny from dlt.common.utils import uniq_id -from dlt.destinations.impl.lancedb.lancedb_adapter import DOCUMENT_ID_HINT +from dlt.destinations.impl.lancedb.lancedb_adapter import ( + DOCUMENT_ID_HINT, + lancedb_adapter, +) +from tests.load.lancedb.utils import chunk_document from tests.load.utils import ( drop_active_pipeline_data, ) @@ -119,10 +126,12 @@ def identity_resource( .reset_index(drop=True) ) - expected_child_data = expected_child_data.sort_values(by="bar").reset_index(drop=True) - expected_grandchild_data = expected_grandchild_data.sort_values(by="baz").reset_index( + expected_child_data = expected_child_data.sort_values(by="bar").reset_index( drop=True ) + expected_grandchild_data = expected_grandchild_data.sort_values( + by="baz" + ).reset_index(drop=True) assert_frame_equal(actual_child_df[["bar"]], expected_child_data) assert_frame_equal(actual_grandchild_df[["baz"]], expected_grandchild_data) @@ -184,7 +193,83 @@ def identity_resource( tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] actual_root_df: DataFrame = ( - tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash"]).reset_index(drop=True) + tbl.to_pandas() + .sort_values(by=["doc_id", "chunk_hash"]) + .reset_index(drop=True) )[["doc_id", "chunk_hash"]] assert_frame_equal(actual_root_df, expected_root_table_df) + + +def test_lancedb_root_table_remove_orphaned_records_with_real_embeddings() -> None: + @dlt.resource( # type: ignore + write_disposition="merge", + table_name="document", + columns={"doc_id": {DOCUMENT_ID_HINT: True}}, + ) + def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: + for doc in docs: + doc_id = doc["doc_id"] + for chunk in chunk_document(doc["text"]): + yield {"doc_id": doc_id, "doc_text": doc["text"], "chunk": chunk} + + @dlt.source() + def documents_source( + docs: List[DictStrAny], + ) -> Any: + return documents(docs) + + lancedb_adapter( + documents, + 
embed=["chunk"], + ) + + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_remove_orphaned_records_with_embeddings", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_{uniq_id()}", + dev_mode=True, + ) + + initial_docs = [ + { + "text": ( + "This is the first document. It contains some text that will be chunked and" + " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)" + ), + "doc_id": 1, + }, + { + "text": "Here's another document. It's a bit different from the first one.", + "doc_id": 2, + }, + ] + + info = pipeline.run(documents_source(initial_docs)) + assert_load_info(info) + + updated_docs = [ + { + "text": "This is the first document, but it has been updated with new content.", + "doc_id": 1, + }, + { + "text": "This is a completely new document that wasn't in the initial set.", + "doc_id": 3, + }, + ] + + info = pipeline.run(documents_source(updated_docs)) + assert_load_info(info) + + with pipeline.destination_client() as client: + embeddings_table_name = client.make_qualified_table_name("document") # type: ignore[attr-defined] + tbl: Table = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] + df = tbl.to_pandas() + + # Check (non-empty) embeddings as present, and that orphaned embeddings have been discarded. + assert len(df) == 21 + assert "vector__" in df.columns + for _, vector in enumerate(df["vector__"]): + assert isinstance(vector, np.ndarray) + assert vector.size > 0 From f2913e9324af4a149a6579d68fa9676ae3ef7ca6 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Fri, 2 Aug 2024 19:36:24 +0200 Subject: [PATCH 026/113] Add merge_key to document resource for efficient updates in LanceDB Signed-off-by: Marcel Coetzee --- tests/load/lancedb/test_remove_orphaned_records.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 2c25a315d4..f4352ed71e 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -205,6 +205,7 @@ def test_lancedb_root_table_remove_orphaned_records_with_real_embeddings() -> No @dlt.resource( # type: ignore write_disposition="merge", table_name="document", + merge_key=["chunk"], columns={"doc_id": {DOCUMENT_ID_HINT: True}}, ) def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: From ffe6584a412e047264f69d4d58f6c25ec9bf1908 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Fri, 2 Aug 2024 19:51:53 +0200 Subject: [PATCH 027/113] Formatting Signed-off-by: Marcel Coetzee --- tests/load/lancedb/test_remove_orphaned_records.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index f4352ed71e..3118b06cc7 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -126,12 +126,10 @@ def identity_resource( .reset_index(drop=True) ) - expected_child_data = expected_child_data.sort_values(by="bar").reset_index( + expected_child_data = expected_child_data.sort_values(by="bar").reset_index(drop=True) + expected_grandchild_data = expected_grandchild_data.sort_values(by="baz").reset_index( drop=True ) - expected_grandchild_data = expected_grandchild_data.sort_values( - by="baz" - ).reset_index(drop=True) assert_frame_equal(actual_child_df[["bar"]], expected_child_data) 
assert_frame_equal(actual_grandchild_df[["baz"]], expected_grandchild_data) @@ -193,9 +191,7 @@ def identity_resource( tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] actual_root_df: DataFrame = ( - tbl.to_pandas() - .sort_values(by=["doc_id", "chunk_hash"]) - .reset_index(drop=True) + tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash"]).reset_index(drop=True) )[["doc_id", "chunk_hash"]] assert_frame_equal(actual_root_df, expected_root_table_df) From 036801829879d34d54052755560debb2f6c3e1f8 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Fri, 2 Aug 2024 20:03:09 +0200 Subject: [PATCH 028/113] Set default file size to 128MB Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/factory.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dlt/destinations/impl/lancedb/factory.py b/dlt/destinations/impl/lancedb/factory.py index f2e17168b9..714c8dcdfd 100644 --- a/dlt/destinations/impl/lancedb/factory.py +++ b/dlt/destinations/impl/lancedb/factory.py @@ -30,6 +30,8 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.decimal_precision = (38, 18) caps.timestamp_precision = 6 + caps.recommended_file_size = 128_000_000 + return caps @property From 02704d50d60dece8e5cc1abf6aa5b63edd561a7a Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sat, 3 Aug 2024 23:37:34 +0200 Subject: [PATCH 029/113] Only use parquet loader file formats Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/factory.py | 4 ++-- dlt/destinations/impl/lancedb/lancedb_client.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dlt/destinations/impl/lancedb/factory.py b/dlt/destinations/impl/lancedb/factory.py index 714c8dcdfd..9acb82344b 100644 --- a/dlt/destinations/impl/lancedb/factory.py +++ b/dlt/destinations/impl/lancedb/factory.py @@ -16,8 +16,8 @@ class lancedb(Destination[LanceDBClientConfiguration, "LanceDBClient"]): def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "jsonl" - caps.supported_loader_file_formats = ["jsonl"] + caps.preferred_loader_file_format = "parquet" + caps.supported_loader_file_formats = ["parquet"] caps.max_identifier_length = 200 caps.max_column_identifier_length = 1024 diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index cc35aadf18..76c3feb032 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -773,8 +773,8 @@ def __init__( TWriteDisposition, self.table_schema.get("write_disposition", "append") ) - with FileStorage.open_zipsafe_ro(local_path) as f: - records: List[DictStrAny] = [json.loads(line) for line in f] + with FileStorage.open_zipsafe_ro(local_path, mode="rb") as f: + arrow_table: pa.Table = pq.read_table(f, memory_map=True) if self.table_schema not in self.schema.dlt_tables(): for record in records: From eae056a5ea0f7c0ef738a0a835fb1ffab260d412 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 4 Aug 2024 17:58:01 +0200 Subject: [PATCH 030/113] Import pyarrow.parquet Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 523 ++++-------------- 1 file changed, 106 insertions(+), 417 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 76c3feb032..04f0e2cac0 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ 
-1,23 +1,11 @@ import uuid from types import TracebackType -from typing import ( - List, - Any, - cast, - Union, - Tuple, - Iterable, - Type, - Optional, - Dict, - Sequence, - TYPE_CHECKING, - Set, -) +from typing import (List, Any, cast, Union, Tuple, Iterable, Type, Optional, Dict, Sequence, TYPE_CHECKING, Set, ) import lancedb # type: ignore import pyarrow as pa import pyarrow.compute as pc +import pyarrow.parquet as pq from lancedb import DBConnection from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore @@ -27,54 +15,23 @@ from dlt.common import json, pendulum, logger from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.exceptions import ( - DestinationUndefinedEntity, - DestinationTransientException, - DestinationTerminalException, -) -from dlt.common.destination.reference import ( - JobClientBase, - WithStateSync, - LoadJob, - StorageSchemaInfo, - StateInfo, - TLoadJobState, - NewLoadJob, - FollowupJob, -) +from dlt.common.destination.exceptions import (DestinationUndefinedEntity, DestinationTransientException, + DestinationTerminalException, ) +from dlt.common.destination.reference import (JobClientBase, WithStateSync, LoadJob, StorageSchemaInfo, StateInfo, + TLoadJobState, NewLoadJob, FollowupJob, ) from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables -from dlt.common.schema.typing import ( - TColumnType, - TTableFormat, - TTableSchemaColumns, - TWriteDisposition, -) +from dlt.common.schema.typing import (TColumnType, TTableFormat, TTableSchemaColumns, TWriteDisposition, ) from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName from dlt.common.typing import DictStrAny -from dlt.destinations.impl.lancedb.configuration import ( - LanceDBClientConfiguration, -) -from dlt.destinations.impl.lancedb.exceptions import ( - lancedb_error, -) -from dlt.destinations.impl.lancedb.lancedb_adapter import ( - VECTORIZE_HINT, - DOCUMENT_ID_HINT, -) -from dlt.destinations.impl.lancedb.schema import ( - make_arrow_field_schema, - make_arrow_table_schema, - TArrowSchema, - NULL_SCHEMA, - TArrowField, -) -from dlt.destinations.impl.lancedb.utils import ( - list_merge_identifiers, - generate_uuid, - set_non_standard_providers_environment_variables, -) +from dlt.destinations.impl.lancedb.configuration import (LanceDBClientConfiguration, ) +from dlt.destinations.impl.lancedb.exceptions import (lancedb_error, ) +from dlt.destinations.impl.lancedb.lancedb_adapter import (VECTORIZE_HINT, DOCUMENT_ID_HINT, ) +from dlt.destinations.impl.lancedb.schema import (make_arrow_field_schema, make_arrow_table_schema, TArrowSchema, + NULL_SCHEMA, TArrowField, ) +from dlt.destinations.impl.lancedb.utils import (list_merge_identifiers, generate_uuid, + set_non_standard_providers_environment_variables, ) from dlt.destinations.job_impl import EmptyLoadJob, NewLoadJobImpl from dlt.destinations.type_mapping import TypeMapper @@ -88,80 +45,38 @@ class LanceDBTypeMapper(TypeMapper): - sct_to_unbound_dbt = { - "text": pa.string(), - "double": pa.float64(), - "bool": pa.bool_(), - "bigint": pa.int64(), - "binary": pa.binary(), - "date": pa.date32(), - "complex": pa.string(), - } + sct_to_unbound_dbt = {"text": pa.string(), "double": pa.float64(), "bool": pa.bool_(), "bigint": pa.int64(), "binary": pa.binary(), "date": pa.date32(), "complex": 
pa.string(), } sct_to_dbt = {} - dbt_to_sct = { - pa.string(): "text", - pa.float64(): "double", - pa.bool_(): "bool", - pa.int64(): "bigint", - pa.binary(): "binary", - pa.date32(): "date", - } - - def to_db_decimal_type( - self, precision: Optional[int], scale: Optional[int] - ) -> pa.Decimal128Type: + dbt_to_sct = {pa.string(): "text", pa.float64(): "double", pa.bool_(): "bool", pa.int64(): "bigint", pa.binary(): "binary", pa.date32(): "date", } + + def to_db_decimal_type(self, precision: Optional[int], scale: Optional[int]) -> pa.Decimal128Type: precision, scale = self.decimal_precision(precision, scale) return pa.decimal128(precision, scale) - def to_db_datetime_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> pa.TimestampType: + def to_db_datetime_type(self, precision: Optional[int], table_format: TTableFormat = None) -> pa.TimestampType: unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] return pa.timestamp(unit, "UTC") - def to_db_time_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> pa.Time64Type: + def to_db_time_type(self, precision: Optional[int], table_format: TTableFormat = None) -> pa.Time64Type: unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] return pa.time64(unit) - def from_db_type( - self, - db_type: pa.DataType, - precision: Optional[int] = None, - scale: Optional[int] = None, - ) -> TColumnType: + def from_db_type(self, db_type: pa.DataType, precision: Optional[int] = None, scale: Optional[int] = None, ) -> TColumnType: if isinstance(db_type, pa.TimestampType): - return dict( - data_type="timestamp", - precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], - scale=scale, - ) + return dict(data_type="timestamp", precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], scale=scale, ) if isinstance(db_type, pa.Time64Type): - return dict( - data_type="time", - precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], - scale=scale, - ) + return dict(data_type="time", precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], scale=scale, ) if isinstance(db_type, pa.Decimal128Type): precision, scale = db_type.precision, db_type.scale - if (precision, scale) == self.capabilities.wei_precision: + if (precision, scale)==self.capabilities.wei_precision: return cast(TColumnType, dict(data_type="wei")) return dict(data_type="decimal", precision=precision, scale=scale) return super().from_db_type(cast(str, db_type), precision, scale) -def upload_batch( - records: List[DictStrAny], - /, - *, - db_client: DBConnection, - table_name: str, - write_disposition: Optional[TWriteDisposition] = "append", - id_field_name: Optional[str] = None, -) -> None: +def upload_batch(records: List[DictStrAny], /, *, db_client: DBConnection, table_name: str, write_disposition: Optional[TWriteDisposition] = "append", id_field_name: Optional[str] = None, ) -> None: """Inserts records into a LanceDB table with automatic embedding computation. Args: @@ -180,30 +95,22 @@ def upload_batch( tbl = db_client.open_table(table_name) tbl.checkout_latest() except FileNotFoundError as e: - raise DestinationTransientException( - "Couldn't open lancedb database. Batch WILL BE RETRIED" - ) from e + raise DestinationTransientException("Couldn't open lancedb database. 
Batch WILL BE RETRIED") from e try: if write_disposition in ("append", "skip"): tbl.add(records) - elif write_disposition == "replace": + elif write_disposition=="replace": tbl.add(records, mode="overwrite") - elif write_disposition == "merge": + elif write_disposition=="merge": if not id_field_name: raise ValueError("To perform a merge update, 'id_field_name' must be specified.") - tbl.merge_insert( - id_field_name - ).when_matched_update_all().when_not_matched_insert_all().execute(records) + tbl.merge_insert(id_field_name).when_matched_update_all().when_not_matched_insert_all().execute(records) else: - raise DestinationTerminalException( - f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" - " failed AND WILL **NOT** BE RETRIED." - ) + raise DestinationTerminalException(f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" + " failed AND WILL **NOT** BE RETRIED.") except ArrowInvalid as e: - raise DestinationTerminalException( - "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." - ) from e + raise DestinationTerminalException("Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED.") from e class LanceDBClient(JobClientBase, WithStateSync): @@ -211,19 +118,10 @@ class LanceDBClient(JobClientBase, WithStateSync): model_func: TextEmbeddingFunction - def __init__( - self, - schema: Schema, - config: LanceDBClientConfiguration, - capabilities: DestinationCapabilitiesContext, - ) -> None: + def __init__(self, schema: Schema, config: LanceDBClientConfiguration, capabilities: DestinationCapabilitiesContext, ) -> None: super().__init__(schema, config, capabilities) self.config: LanceDBClientConfiguration = config - self.db_client: DBConnection = lancedb.connect( - uri=self.config.credentials.uri, - api_key=self.config.credentials.api_key, - read_consistency_interval=timedelta(0), - ) + self.db_client: DBConnection = lancedb.connect(uri=self.config.credentials.uri, api_key=self.config.credentials.api_key, read_consistency_interval=timedelta(0), ) self.registry = EmbeddingFunctionRegistry.get_instance() self.type_mapper = LanceDBTypeMapper(self.capabilities) self.sentinel_table_name = config.sentinel_table_name @@ -233,24 +131,14 @@ def __init__( # LanceDB doesn't provide a standardized way to set API keys across providers. # Some use ENV variables and others allow passing api key as an argument. # To account for this, we set provider environment variable as well. - set_non_standard_providers_environment_variables( - embedding_model_provider, - self.config.credentials.embedding_model_provider_api_key, - ) + set_non_standard_providers_environment_variables(embedding_model_provider, self.config.credentials.embedding_model_provider_api_key, ) # Use the monkey-patched implementation if openai was chosen. 
- if embedding_model_provider == "openai": + if embedding_model_provider=="openai": from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings - self.model_func = PatchedOpenAIEmbeddings( - max_retries=self.config.options.max_retries, - api_key=self.config.credentials.api_key, - ) + self.model_func = PatchedOpenAIEmbeddings(max_retries=self.config.options.max_retries, api_key=self.config.credentials.api_key, ) else: - self.model_func = self.registry.get(embedding_model_provider).create( - name=self.config.embedding_model, - max_retries=self.config.options.max_retries, - api_key=self.config.credentials.api_key, - ) + self.model_func = self.registry.get(embedding_model_provider).create(name=self.config.embedding_model, max_retries=self.config.options.max_retries, api_key=self.config.credentials.api_key, ) self.vector_field_name = self.config.vector_field_name self.id_field_name = self.config.id_field_name @@ -264,20 +152,13 @@ def sentinel_table(self) -> str: return self.make_qualified_table_name(self.sentinel_table_name) def make_qualified_table_name(self, table_name: str) -> str: - return ( - f"{self.dataset_name}{self.config.dataset_separator}{table_name}" - if self.dataset_name - else table_name - ) + return (f"{self.dataset_name}{self.config.dataset_separator}{table_name}" if self.dataset_name else table_name) def get_table_schema(self, table_name: str) -> TArrowSchema: schema_table: Table = self.db_client.open_table(table_name) schema_table.checkout_latest() schema = schema_table.schema - return cast( - TArrowSchema, - schema, - ) + return cast(TArrowSchema, schema, ) @lancedb_error def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table: @@ -300,11 +181,7 @@ def delete_table(self, table_name: str) -> None: """ self.db_client.drop_table(table_name) - def query_table( - self, - table_name: str, - query: Union[List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None] = None, - ) -> LanceQueryBuilder: + def query_table(self, table_name: str, query: Union[List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None] = None, ) -> LanceQueryBuilder: """Query a LanceDB table. 
Args: @@ -323,15 +200,11 @@ def _get_table_names(self) -> List[str]: """Return all tables in the dataset, excluding the sentinel table.""" if self.dataset_name: prefix = f"{self.dataset_name}{self.config.dataset_separator}" - table_names = [ - table_name - for table_name in self.db_client.table_names() - if table_name.startswith(prefix) - ] + table_names = [table_name for table_name in self.db_client.table_names() if table_name.startswith(prefix)] else: table_names = self.db_client.table_names() - return [table_name for table_name in table_names if table_name != self.sentinel_table] + return [table_name for table_name in table_names if table_name!=self.sentinel_table] @lancedb_error def drop_storage(self) -> None: @@ -357,10 +230,7 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: continue schema = self.get_table_schema(fq_table_name) self.db_client.drop_table(fq_table_name) - self.create_table( - table_name=fq_table_name, - schema=schema, - ) + self.create_table(table_name=fq_table_name, schema=schema, ) @lancedb_error def is_storage_initialized(self) -> bool: @@ -375,11 +245,7 @@ def _delete_sentinel_table(self) -> None: self.db_client.drop_table(self.sentinel_table) @lancedb_error - def update_stored_schema( - self, - only_tables: Iterable[str] = None, - expected_update: TSchemaTables = None, - ) -> Optional[TSchemaTables]: + def update_stored_schema(self, only_tables: Iterable[str] = None, expected_update: TSchemaTables = None, ) -> Optional[TSchemaTables]: super().update_stored_schema(only_tables, expected_update) applied_update: TSchemaTables = {} @@ -389,17 +255,13 @@ def update_stored_schema( schema_info = None if schema_info is None: - logger.info( - f"Schema with hash {self.schema.stored_version_hash} " - "not found in the storage. upgrading" - ) + logger.info(f"Schema with hash {self.schema.stored_version_hash} " + "not found in the storage. upgrading") self._execute_schema_update(only_tables) else: - logger.info( - f"Schema with hash {self.schema.stored_version_hash} " - f"inserted at {schema_info.inserted_at} found " - "in storage, no upgrade required" - ) + logger.info(f"Schema with hash {self.schema.stored_version_hash} " + f"inserted at {schema_info.inserted_at} found " + "in storage, no upgrade required") return applied_update def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: @@ -417,16 +279,11 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] field: TArrowField for field in arrow_schema: name = self.schema.naming.normalize_identifier(field.name) - table_schema[name] = { - "name": name, - **self.type_mapper.from_db_type(field.type), - } + table_schema[name] = {"name": name, **self.type_mapper.from_db_type(field.type), } return True, table_schema @lancedb_error - def add_table_fields( - self, table_name: str, field_schemas: List[TArrowField] - ) -> Optional[Table]: + def add_table_fields(self, table_name: str, field_schemas: List[TArrowField]) -> Optional[Table]: """Add multiple fields to the LanceDB table at once. 
Args: @@ -459,25 +316,16 @@ def add_table_fields( def _execute_schema_update(self, only_tables: Iterable[str]) -> None: for table_name in only_tables or self.schema.tables: exists, existing_columns = self.get_storage_table(table_name) - new_columns = self.schema.get_new_table_columns( - table_name, - existing_columns, - self.capabilities.generates_case_sensitive_identifiers(), - ) + new_columns = self.schema.get_new_table_columns(table_name, existing_columns, self.capabilities.generates_case_sensitive_identifiers(), ) logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") if len(new_columns) > 0: if exists: - field_schemas: List[TArrowField] = [ - make_arrow_field_schema(column["name"], column, self.type_mapper) - for column in new_columns - ] + field_schemas: List[TArrowField] = [make_arrow_field_schema(column["name"], column, self.type_mapper) for column in new_columns] fq_table_name = self.make_qualified_table_name(table_name) self.add_table_fields(fq_table_name, field_schemas) else: if table_name not in self.schema.dlt_table_names(): - embedding_fields = get_columns_names_with_prop( - self.schema.get_table(table_name=table_name), VECTORIZE_HINT - ) + embedding_fields = get_columns_names_with_prop(self.schema.get_table(table_name=table_name), VECTORIZE_HINT) vector_field_name = self.vector_field_name id_field_name = self.id_field_name embedding_model_func = self.model_func @@ -489,16 +337,7 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: embedding_model_func = None embedding_model_dimensions = None - table_schema: TArrowSchema = make_arrow_table_schema( - table_name, - schema=self.schema, - type_mapper=self.type_mapper, - embedding_fields=embedding_fields, - embedding_model_func=embedding_model_func, - embedding_model_dimensions=embedding_model_dimensions, - vector_field_name=vector_field_name, - id_field_name=id_field_name, - ) + table_schema: TArrowSchema = make_arrow_table_schema(table_name, schema=self.schema, type_mapper=self.type_mapper, embedding_fields=embedding_fields, embedding_model_func=embedding_model_func, embedding_model_dimensions=embedding_model_dimensions, vector_field_name=vector_field_name, id_field_name=id_field_name, ) fq_table_name = self.make_qualified_table_name(table_name) self.create_table(fq_table_name, table_schema) @@ -506,33 +345,13 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: @lancedb_error def update_schema_in_storage(self) -> None: - records = [ - { - self.schema.naming.normalize_identifier("version"): self.schema.version, - self.schema.naming.normalize_identifier( - "engine_version" - ): self.schema.ENGINE_VERSION, - self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), - self.schema.naming.normalize_identifier("schema_name"): self.schema.name, - self.schema.naming.normalize_identifier( - "version_hash" - ): self.schema.stored_version_hash, - self.schema.naming.normalize_identifier("schema"): json.dumps( - self.schema.to_dict() - ), - } - ] + records = [{self.schema.naming.normalize_identifier("version"): self.schema.version, self.schema.naming.normalize_identifier("engine_version"): self.schema.ENGINE_VERSION, + self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier("version_hash"): self.schema.stored_version_hash, self.schema.naming.normalize_identifier("schema"): json.dumps(self.schema.to_dict()), }] 
fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) - write_disposition = self.schema.get_table(self.schema.version_table_name).get( - "write_disposition" - ) + write_disposition = self.schema.get_table(self.schema.version_table_name).get("write_disposition") - upload_batch( - records, - db_client=self.db_client, - table_name=fq_version_table_name, - write_disposition=write_disposition, - ) + upload_batch(records, db_client=self.db_client, table_name=fq_version_table_name, write_disposition=write_disposition, ) @lancedb_error def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: @@ -559,31 +378,18 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: # Read the tables into memory as Arrow tables, with pushdown predicates, so we pull as less # data into memory as possible. - state_table = ( - state_table_.search() - .where(f"`{p_pipeline_name}` = '{pipeline_name}'", prefilter=True) - .to_arrow() - ) + state_table = (state_table_.search().where(f"`{p_pipeline_name}` = '{pipeline_name}'", prefilter=True).to_arrow()) loads_table = loads_table_.search().where(f"`{p_status}` = 0", prefilter=True).to_arrow() # Join arrow tables in-memory. - joined_table: pa.Table = state_table.join( - loads_table, keys=p_dlt_load_id, right_keys=p_load_id, join_type="inner" - ).sort_by([(p_dlt_load_id, "descending")]) + joined_table: pa.Table = state_table.join(loads_table, keys=p_dlt_load_id, right_keys=p_load_id, join_type="inner").sort_by([(p_dlt_load_id, "descending")]) - if joined_table.num_rows == 0: + if joined_table.num_rows==0: return None state = joined_table.take([0]).to_pylist()[0] - return StateInfo( - version=state[p_version], - engine_version=state[p_engine_version], - pipeline_name=state[p_pipeline_name], - state=state[p_state], - created_at=pendulum.instance(state[p_created_at]), - version_hash=state[p_version_hash], - _dlt_load_id=state[p_dlt_load_id], - ) + return StateInfo(version=state[p_version], engine_version=state[p_engine_version], pipeline_name=state[p_pipeline_name], state=state[p_state], created_at=pendulum.instance( + state[p_created_at]), version_hash=state[p_version_hash], _dlt_load_id=state[p_dlt_load_id], ) @lancedb_error def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: @@ -599,21 +405,11 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI p_schema = self.schema.naming.normalize_identifier("schema") try: - schemas = ( - version_table.search().where( - f'`{p_version_hash}` = "{schema_hash}"', prefilter=True - ) - ).to_list() + schemas = (version_table.search().where(f'`{p_version_hash}` = "{schema_hash}"', prefilter=True)).to_list() most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] - return StorageSchemaInfo( - version_hash=most_recent_schema[p_version_hash], - schema_name=most_recent_schema[p_schema_name], - version=most_recent_schema[p_version], - engine_version=most_recent_schema[p_engine_version], - inserted_at=most_recent_schema[p_inserted_at], - schema=most_recent_schema[p_schema], - ) + return StorageSchemaInfo(version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], version=most_recent_schema[p_version], engine_version= + most_recent_schema[p_engine_version], inserted_at=most_recent_schema[p_inserted_at], schema=most_recent_schema[p_schema], ) except IndexError: return None @@ -632,30 +428,15 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: 
p_schema = self.schema.naming.normalize_identifier("schema") try: - schemas = ( - version_table.search().where( - f'`{p_schema_name}` = "{self.schema.name}"', prefilter=True - ) - ).to_list() + schemas = (version_table.search().where(f'`{p_schema_name}` = "{self.schema.name}"', prefilter=True)).to_list() most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] - return StorageSchemaInfo( - version_hash=most_recent_schema[p_version_hash], - schema_name=most_recent_schema[p_schema_name], - version=most_recent_schema[p_version], - engine_version=most_recent_schema[p_engine_version], - inserted_at=most_recent_schema[p_inserted_at], - schema=most_recent_schema[p_schema], - ) + return StorageSchemaInfo(version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], version=most_recent_schema[p_version], engine_version= + most_recent_schema[p_engine_version], inserted_at=most_recent_schema[p_inserted_at], schema=most_recent_schema[p_schema], ) except IndexError: return None - def __exit__( - self, - exc_type: Type[BaseException], - exc_val: BaseException, - exc_tb: TracebackType, - ) -> None: + def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType, ) -> None: pass def __enter__(self) -> "LanceDBClient": @@ -663,27 +444,13 @@ def __enter__(self) -> "LanceDBClient": @lancedb_error def complete_load(self, load_id: str) -> None: - records = [ - { - self.schema.naming.normalize_identifier("load_id"): load_id, - self.schema.naming.normalize_identifier("schema_name"): self.schema.name, - self.schema.naming.normalize_identifier("status"): 0, - self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), - self.schema.naming.normalize_identifier( - "schema_version_hash" - ): None, # Payload schema must match the target schema. - } - ] + records = [{self.schema.naming.normalize_identifier("load_id"): load_id, self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier("status"): 0, self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), + self.schema.naming.normalize_identifier("schema_version_hash"): None, # Payload schema must match the target schema. 
+ }] fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) - write_disposition = self.schema.get_table(self.schema.loads_table_name).get( - "write_disposition" - ) - upload_batch( - records, - db_client=self.db_client, - table_name=fq_loads_table_name, - write_disposition=write_disposition, - ) + write_disposition = self.schema.get_table(self.schema.loads_table_name).get("write_disposition") + upload_batch(records, db_client=self.db_client, table_name=fq_loads_table_name, write_disposition=write_disposition, ) def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") @@ -691,48 +458,22 @@ def restore_file_load(self, file_path: str) -> LoadJob: def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: parent_table = table.get("parent") - return LoadLanceDBJob( - self.schema, - table, - file_path, - type_mapper=self.type_mapper, - db_client=self.db_client, - client_config=self.config, - model_func=self.model_func, - fq_table_name=self.make_qualified_table_name(table["name"]), - fq_parent_table_name=( - self.make_qualified_table_name(parent_table) if parent_table else None - ), - ) - - def create_table_chain_completed_followup_jobs( - self, - table_chain: Sequence[TTableSchema], - completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[NewLoadJob]: + return LoadLanceDBJob(self.schema, table, file_path, type_mapper=self.type_mapper, db_client=self.db_client, client_config=self.config, model_func=self.model_func, fq_table_name=self.make_qualified_table_name( + table["name"]), fq_parent_table_name=(self.make_qualified_table_name(parent_table) if parent_table else None), ) + + def create_table_chain_completed_followup_jobs(self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[NewLoadJob]: assert completed_table_chain_jobs is not None - jobs = super().create_table_chain_completed_followup_jobs( - table_chain, completed_table_chain_jobs - ) + jobs = super().create_table_chain_completed_followup_jobs(table_chain, completed_table_chain_jobs) for table in table_chain: if table in self.schema.dlt_tables(): continue # Only tables with merge disposition are dispatched for orphan removal jobs. 
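# For illustration, a resource declared roughly like this (names are made up) is
# what ends up on the merge path below and therefore gets an orphan-removal
# followup job for its table chain.
import dlt

@dlt.resource(write_disposition="merge", primary_key="doc_id", table_name="document")
def _documents(docs):
    yield from docs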
- if table.get("write_disposition") == "merge": + if table.get("write_disposition")=="merge": parent_table = table.get("parent") - jobs.append( - LanceDBRemoveOrphansJob( - db_client=self.db_client, - table_schema=self.prepare_load_table(table["name"]), - fq_table_name=self.make_qualified_table_name(table["name"]), - fq_parent_table_name=( - self.make_qualified_table_name(parent_table) if parent_table else None - ), - client_config=self.config, - ) - ) + jobs.append(LanceDBRemoveOrphansJob(db_client=self.db_client, table_schema=self.prepare_load_table(table["name"]), fq_table_name=self.make_qualified_table_name( + table["name"]), fq_parent_table_name=(self.make_qualified_table_name(parent_table) if parent_table else None), client_config=self.config, )) return jobs @@ -743,18 +484,8 @@ def table_exists(self, table_name: str) -> bool: class LoadLanceDBJob(LoadJob, FollowupJob): arrow_schema: TArrowSchema - def __init__( - self, - schema: Schema, - table_schema: TTableSchema, - local_path: str, - type_mapper: LanceDBTypeMapper, - db_client: DBConnection, - client_config: LanceDBClientConfiguration, - model_func: TextEmbeddingFunction, - fq_table_name: str, - fq_parent_table_name: Optional[str], - ) -> None: + def __init__(self, schema: Schema, table_schema: TTableSchema, local_path: str, type_mapper: LanceDBTypeMapper, db_client: DBConnection, client_config: LanceDBClientConfiguration, model_func: TextEmbeddingFunction, fq_table_name: str, fq_parent_table_name: + Optional[str], ) -> None: file_name = FileStorage.get_file_name_from_file_path(local_path) super().__init__(file_name) self.schema: Schema = schema @@ -769,9 +500,7 @@ def __init__( self.embedding_model_func: TextEmbeddingFunction = model_func self.embedding_model_dimensions: int = client_config.embedding_model_dimensions self.id_field_name: str = client_config.id_field_name - self.write_disposition: TWriteDisposition = cast( - TWriteDisposition, self.table_schema.get("write_disposition", "append") - ) + self.write_disposition: TWriteDisposition = cast(TWriteDisposition, self.table_schema.get("write_disposition", "append")) with FileStorage.open_zipsafe_ro(local_path, mode="rb") as f: arrow_table: pa.Table = pq.read_table(f, memory_map=True) @@ -779,11 +508,7 @@ def __init__( if self.table_schema not in self.schema.dlt_tables(): for record in records: # Add reserved ID fields. - uuid_id = ( - generate_uuid(record, self.unique_identifiers, self.fq_table_name) - if self.unique_identifiers - else str(uuid.uuid4()) - ) + uuid_id = (generate_uuid(record, self.unique_identifiers, self.fq_table_name) if self.unique_identifiers else str(uuid.uuid4())) record.update({self.id_field_name: uuid_id}) # LanceDB expects all fields in the target arrow table to be present in the data payload. 
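# Sketch only, assuming the target Arrow schema is known: the per-record padding
# this comment refers to (and the loop shown in the next hunk) can also be
# expressed for whole Arrow tables by appending all-null columns for missing fields.
import pyarrow as pa

def _pad_missing_fields(payload: pa.Table, target_schema: pa.Schema) -> pa.Table:
    for field in target_schema:
        if field.name not in payload.schema.names:
            payload = payload.append_column(field, pa.nulls(len(payload), type=field.type))
    return payload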
@@ -792,13 +517,7 @@ def __init__( for field in missing_fields: record[field] = None - upload_batch( - records, - db_client=db_client, - table_name=self.fq_table_name, - write_disposition=self.write_disposition, - id_field_name=self.id_field_name, - ) + upload_batch(records, db_client=db_client, table_name=self.fq_table_name, write_disposition=self.write_disposition, id_field_name=self.id_field_name, ) def state(self) -> TLoadJobState: return "completed" @@ -808,34 +527,17 @@ def exception(self) -> str: class LanceDBRemoveOrphansJob(NewLoadJobImpl): - def __init__( - self, - db_client: DBConnection, - table_schema: TTableSchema, - client_config: LanceDBClientConfiguration, - fq_table_name: str, - fq_parent_table_name: Optional[str], - ) -> None: + def __init__(self, db_client: DBConnection, table_schema: TTableSchema, client_config: LanceDBClientConfiguration, fq_table_name: str, fq_parent_table_name: Optional[str], ) -> None: self.db_client = db_client self.table_schema: TTableSchema = table_schema self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name - self.write_disposition: TWriteDisposition = cast( - TWriteDisposition, self.table_schema.get("write_disposition") - ) + self.write_disposition: TWriteDisposition = cast(TWriteDisposition, self.table_schema.get("write_disposition")) self.id_field_name: str = client_config.id_field_name - job_id = ParsedLoadJobFileName( - table_schema["name"], - ParsedLoadJobFileName.new_file_id(), - 0, - "parquet", - ).file_name() + job_id = ParsedLoadJobFileName(table_schema["name"], ParsedLoadJobFileName.new_file_id(), 0, "parquet", ).file_name() - super().__init__( - file_name=job_id, - status="running", - ) + super().__init__(file_name=job_id, status="running", ) self._save_text_file("") @@ -844,11 +546,9 @@ def __init__( def execute(self) -> None: orphaned_ids: Set[str] - if self.write_disposition != "merge": - raise DestinationTerminalException( - f"Unsupported write disposition {self.write_disposition} for LanceDB Destination" - " Orphan Removal Job - failed AND WILL **NOT** BE RETRIED." - ) + if self.write_disposition!="merge": + raise DestinationTerminalException(f"Unsupported write disposition {self.write_disposition} for LanceDB Destination" + " Orphan Removal Job - failed AND WILL **NOT** BE RETRIED.") # Orphans are removed irrespective of which merge strategy is picked. try: @@ -858,9 +558,7 @@ def execute(self) -> None: parent_table = self.db_client.open_table(self.fq_parent_table_name) parent_table.checkout_latest() except FileNotFoundError as e: - raise DestinationTransientException( - "Couldn't open lancedb database. Orphan removal WILL BE RETRIED" - ) from e + raise DestinationTransientException("Couldn't open lancedb database. Orphan removal WILL BE RETRIED") from e try: if self.fq_parent_table_name: @@ -871,7 +569,7 @@ def execute(self) -> None: if orphaned_ids := child_ids - parent_ids: if len(orphaned_ids) > 1: child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids) == 1: + elif len(orphaned_ids)==1: child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") else: @@ -879,31 +577,22 @@ def execute(self) -> None: # If document ID is defined, we use this as the sole grouping key to identify stale chunks, # else fallback to the compound `id_field_name`. 
- grouping_key = ( - get_columns_names_with_prop(self.table_schema, DOCUMENT_ID_HINT) - or self.id_field_name - ) + grouping_key = (get_columns_names_with_prop(self.table_schema, DOCUMENT_ID_HINT) or self.id_field_name) grouping_key = grouping_key if isinstance(grouping_key, list) else [grouping_key] - child_table_arrow: pa.Table = child_table.to_arrow().select( - [*grouping_key, "_dlt_load_id", "_dlt_id"] - ) + child_table_arrow: pa.Table = child_table.to_arrow().select([*grouping_key, "_dlt_load_id", "_dlt_id"]) - grouped = child_table_arrow.group_by(grouping_key).aggregate( - [("_dlt_load_id", "max")] - ) + grouped = child_table_arrow.group_by(grouping_key).aggregate([("_dlt_load_id", "max")]) joined = child_table_arrow.join(grouped, keys=grouping_key) orphaned_mask = pc.not_equal(joined["_dlt_load_id"], joined["_dlt_load_id_max"]) orphaned_ids = joined.filter(orphaned_mask).column("_dlt_id").to_pylist() if len(orphaned_ids) > 1: child_table.delete(f"_dlt_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids) == 1: + elif len(orphaned_ids)==1: child_table.delete(f"_dlt_id = '{orphaned_ids.pop()}'") except ArrowInvalid as e: - raise DestinationTerminalException( - "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." - ) from e + raise DestinationTerminalException("Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED.") from e def state(self) -> TLoadJobState: return "completed" From dc20a55986e6866993577d2825504e5fd9eb7b63 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 4 Aug 2024 23:12:28 +0200 Subject: [PATCH 031/113] Remove recommended file size from LanceDB destination capabilities Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/factory.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dlt/destinations/impl/lancedb/factory.py b/dlt/destinations/impl/lancedb/factory.py index 9acb82344b..beab6a4031 100644 --- a/dlt/destinations/impl/lancedb/factory.py +++ b/dlt/destinations/impl/lancedb/factory.py @@ -30,8 +30,6 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.decimal_precision = (38, 18) caps.timestamp_precision = 6 - caps.recommended_file_size = 128_000_000 - return caps @property From 6ed540b0b4fbc6281abc66b727658aeeb3701d06 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 4 Aug 2024 23:26:14 +0200 Subject: [PATCH 032/113] Update LanceDB client to use more efficient batch processing methods on loading for Load Jobs Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 537 ++++++++++++++---- dlt/destinations/impl/lancedb/utils.py | 36 +- 2 files changed, 447 insertions(+), 126 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 04f0e2cac0..1294155c00 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -1,6 +1,18 @@ -import uuid from types import TracebackType -from typing import (List, Any, cast, Union, Tuple, Iterable, Type, Optional, Dict, Sequence, TYPE_CHECKING, Set, ) +from typing import ( + List, + Any, + cast, + Union, + Tuple, + Iterable, + Type, + Optional, + Dict, + Sequence, + TYPE_CHECKING, + Set, +) import lancedb # type: ignore import pyarrow as pa @@ -15,23 +27,54 @@ from dlt.common import json, pendulum, logger from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.exceptions import (DestinationUndefinedEntity, DestinationTransientException, - 
DestinationTerminalException, ) -from dlt.common.destination.reference import (JobClientBase, WithStateSync, LoadJob, StorageSchemaInfo, StateInfo, - TLoadJobState, NewLoadJob, FollowupJob, ) +from dlt.common.destination.exceptions import ( + DestinationUndefinedEntity, + DestinationTransientException, + DestinationTerminalException, +) +from dlt.common.destination.reference import ( + JobClientBase, + WithStateSync, + LoadJob, + StorageSchemaInfo, + StateInfo, + TLoadJobState, + NewLoadJob, + FollowupJob, +) from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables -from dlt.common.schema.typing import (TColumnType, TTableFormat, TTableSchemaColumns, TWriteDisposition, ) +from dlt.common.schema.typing import ( + TColumnType, + TTableFormat, + TTableSchemaColumns, + TWriteDisposition, +) from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName from dlt.common.typing import DictStrAny -from dlt.destinations.impl.lancedb.configuration import (LanceDBClientConfiguration, ) -from dlt.destinations.impl.lancedb.exceptions import (lancedb_error, ) -from dlt.destinations.impl.lancedb.lancedb_adapter import (VECTORIZE_HINT, DOCUMENT_ID_HINT, ) -from dlt.destinations.impl.lancedb.schema import (make_arrow_field_schema, make_arrow_table_schema, TArrowSchema, - NULL_SCHEMA, TArrowField, ) -from dlt.destinations.impl.lancedb.utils import (list_merge_identifiers, generate_uuid, - set_non_standard_providers_environment_variables, ) +from dlt.destinations.impl.lancedb.configuration import ( + LanceDBClientConfiguration, +) +from dlt.destinations.impl.lancedb.exceptions import ( + lancedb_error, +) +from dlt.destinations.impl.lancedb.lancedb_adapter import ( + VECTORIZE_HINT, + DOCUMENT_ID_HINT, +) +from dlt.destinations.impl.lancedb.schema import ( + make_arrow_field_schema, + make_arrow_table_schema, + TArrowSchema, + NULL_SCHEMA, + TArrowField, +) +from dlt.destinations.impl.lancedb.utils import ( + list_merge_identifiers, + set_non_standard_providers_environment_variables, + generate_arrow_uuid_column, +) from dlt.destinations.job_impl import EmptyLoadJob, NewLoadJobImpl from dlt.destinations.type_mapping import TypeMapper @@ -45,38 +88,80 @@ class LanceDBTypeMapper(TypeMapper): - sct_to_unbound_dbt = {"text": pa.string(), "double": pa.float64(), "bool": pa.bool_(), "bigint": pa.int64(), "binary": pa.binary(), "date": pa.date32(), "complex": pa.string(), } + sct_to_unbound_dbt = { + "text": pa.string(), + "double": pa.float64(), + "bool": pa.bool_(), + "bigint": pa.int64(), + "binary": pa.binary(), + "date": pa.date32(), + "complex": pa.string(), + } sct_to_dbt = {} - dbt_to_sct = {pa.string(): "text", pa.float64(): "double", pa.bool_(): "bool", pa.int64(): "bigint", pa.binary(): "binary", pa.date32(): "date", } - - def to_db_decimal_type(self, precision: Optional[int], scale: Optional[int]) -> pa.Decimal128Type: + dbt_to_sct = { + pa.string(): "text", + pa.float64(): "double", + pa.bool_(): "bool", + pa.int64(): "bigint", + pa.binary(): "binary", + pa.date32(): "date", + } + + def to_db_decimal_type( + self, precision: Optional[int], scale: Optional[int] + ) -> pa.Decimal128Type: precision, scale = self.decimal_precision(precision, scale) return pa.decimal128(precision, scale) - def to_db_datetime_type(self, precision: Optional[int], table_format: TTableFormat = None) -> pa.TimestampType: + def to_db_datetime_type( + self, precision: Optional[int], table_format: 
TTableFormat = None + ) -> pa.TimestampType: unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] return pa.timestamp(unit, "UTC") - def to_db_time_type(self, precision: Optional[int], table_format: TTableFormat = None) -> pa.Time64Type: + def to_db_time_type( + self, precision: Optional[int], table_format: TTableFormat = None + ) -> pa.Time64Type: unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] return pa.time64(unit) - def from_db_type(self, db_type: pa.DataType, precision: Optional[int] = None, scale: Optional[int] = None, ) -> TColumnType: + def from_db_type( + self, + db_type: pa.DataType, + precision: Optional[int] = None, + scale: Optional[int] = None, + ) -> TColumnType: if isinstance(db_type, pa.TimestampType): - return dict(data_type="timestamp", precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], scale=scale, ) + return dict( + data_type="timestamp", + precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], + scale=scale, + ) if isinstance(db_type, pa.Time64Type): - return dict(data_type="time", precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], scale=scale, ) + return dict( + data_type="time", + precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], + scale=scale, + ) if isinstance(db_type, pa.Decimal128Type): precision, scale = db_type.precision, db_type.scale - if (precision, scale)==self.capabilities.wei_precision: + if (precision, scale) == self.capabilities.wei_precision: return cast(TColumnType, dict(data_type="wei")) return dict(data_type="decimal", precision=precision, scale=scale) return super().from_db_type(cast(str, db_type), precision, scale) -def upload_batch(records: List[DictStrAny], /, *, db_client: DBConnection, table_name: str, write_disposition: Optional[TWriteDisposition] = "append", id_field_name: Optional[str] = None, ) -> None: +def write_to_db( + records: Union[pa.Table, List[DictStrAny]], + /, + *, + db_client: DBConnection, + table_name: str, + write_disposition: Optional[TWriteDisposition] = "append", + id_field_name: Optional[str] = None, +) -> None: """Inserts records into a LanceDB table with automatic embedding computation. Args: @@ -95,22 +180,30 @@ def upload_batch(records: List[DictStrAny], /, *, db_client: DBConnection, table tbl = db_client.open_table(table_name) tbl.checkout_latest() except FileNotFoundError as e: - raise DestinationTransientException("Couldn't open lancedb database. Batch WILL BE RETRIED") from e + raise DestinationTransientException( + "Couldn't open lancedb database. Batch WILL BE RETRIED" + ) from e try: if write_disposition in ("append", "skip"): tbl.add(records) - elif write_disposition=="replace": + elif write_disposition == "replace": tbl.add(records, mode="overwrite") - elif write_disposition=="merge": + elif write_disposition == "merge": if not id_field_name: raise ValueError("To perform a merge update, 'id_field_name' must be specified.") - tbl.merge_insert(id_field_name).when_matched_update_all().when_not_matched_insert_all().execute(records) + tbl.merge_insert( + id_field_name + ).when_matched_update_all().when_not_matched_insert_all().execute(records) else: - raise DestinationTerminalException(f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" - " failed AND WILL **NOT** BE RETRIED.") + raise DestinationTerminalException( + f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" + " failed AND WILL **NOT** BE RETRIED." 
+ ) except ArrowInvalid as e: - raise DestinationTerminalException("Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED.") from e + raise DestinationTerminalException( + "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." + ) from e class LanceDBClient(JobClientBase, WithStateSync): @@ -118,10 +211,19 @@ class LanceDBClient(JobClientBase, WithStateSync): model_func: TextEmbeddingFunction - def __init__(self, schema: Schema, config: LanceDBClientConfiguration, capabilities: DestinationCapabilitiesContext, ) -> None: + def __init__( + self, + schema: Schema, + config: LanceDBClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: super().__init__(schema, config, capabilities) self.config: LanceDBClientConfiguration = config - self.db_client: DBConnection = lancedb.connect(uri=self.config.credentials.uri, api_key=self.config.credentials.api_key, read_consistency_interval=timedelta(0), ) + self.db_client: DBConnection = lancedb.connect( + uri=self.config.credentials.uri, + api_key=self.config.credentials.api_key, + read_consistency_interval=timedelta(0), + ) self.registry = EmbeddingFunctionRegistry.get_instance() self.type_mapper = LanceDBTypeMapper(self.capabilities) self.sentinel_table_name = config.sentinel_table_name @@ -131,14 +233,24 @@ def __init__(self, schema: Schema, config: LanceDBClientConfiguration, capabilit # LanceDB doesn't provide a standardized way to set API keys across providers. # Some use ENV variables and others allow passing api key as an argument. # To account for this, we set provider environment variable as well. - set_non_standard_providers_environment_variables(embedding_model_provider, self.config.credentials.embedding_model_provider_api_key, ) + set_non_standard_providers_environment_variables( + embedding_model_provider, + self.config.credentials.embedding_model_provider_api_key, + ) # Use the monkey-patched implementation if openai was chosen. 
- if embedding_model_provider=="openai": + if embedding_model_provider == "openai": from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings - self.model_func = PatchedOpenAIEmbeddings(max_retries=self.config.options.max_retries, api_key=self.config.credentials.api_key, ) + self.model_func = PatchedOpenAIEmbeddings( + max_retries=self.config.options.max_retries, + api_key=self.config.credentials.api_key, + ) else: - self.model_func = self.registry.get(embedding_model_provider).create(name=self.config.embedding_model, max_retries=self.config.options.max_retries, api_key=self.config.credentials.api_key, ) + self.model_func = self.registry.get(embedding_model_provider).create( + name=self.config.embedding_model, + max_retries=self.config.options.max_retries, + api_key=self.config.credentials.api_key, + ) self.vector_field_name = self.config.vector_field_name self.id_field_name = self.config.id_field_name @@ -152,13 +264,20 @@ def sentinel_table(self) -> str: return self.make_qualified_table_name(self.sentinel_table_name) def make_qualified_table_name(self, table_name: str) -> str: - return (f"{self.dataset_name}{self.config.dataset_separator}{table_name}" if self.dataset_name else table_name) + return ( + f"{self.dataset_name}{self.config.dataset_separator}{table_name}" + if self.dataset_name + else table_name + ) def get_table_schema(self, table_name: str) -> TArrowSchema: schema_table: Table = self.db_client.open_table(table_name) schema_table.checkout_latest() schema = schema_table.schema - return cast(TArrowSchema, schema, ) + return cast( + TArrowSchema, + schema, + ) @lancedb_error def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table: @@ -181,7 +300,11 @@ def delete_table(self, table_name: str) -> None: """ self.db_client.drop_table(table_name) - def query_table(self, table_name: str, query: Union[List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None] = None, ) -> LanceQueryBuilder: + def query_table( + self, + table_name: str, + query: Union[List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None] = None, + ) -> LanceQueryBuilder: """Query a LanceDB table. 
Args: @@ -200,11 +323,15 @@ def _get_table_names(self) -> List[str]: """Return all tables in the dataset, excluding the sentinel table.""" if self.dataset_name: prefix = f"{self.dataset_name}{self.config.dataset_separator}" - table_names = [table_name for table_name in self.db_client.table_names() if table_name.startswith(prefix)] + table_names = [ + table_name + for table_name in self.db_client.table_names() + if table_name.startswith(prefix) + ] else: table_names = self.db_client.table_names() - return [table_name for table_name in table_names if table_name!=self.sentinel_table] + return [table_name for table_name in table_names if table_name != self.sentinel_table] @lancedb_error def drop_storage(self) -> None: @@ -230,7 +357,10 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: continue schema = self.get_table_schema(fq_table_name) self.db_client.drop_table(fq_table_name) - self.create_table(table_name=fq_table_name, schema=schema, ) + self.create_table( + table_name=fq_table_name, + schema=schema, + ) @lancedb_error def is_storage_initialized(self) -> bool: @@ -245,7 +375,11 @@ def _delete_sentinel_table(self) -> None: self.db_client.drop_table(self.sentinel_table) @lancedb_error - def update_stored_schema(self, only_tables: Iterable[str] = None, expected_update: TSchemaTables = None, ) -> Optional[TSchemaTables]: + def update_stored_schema( + self, + only_tables: Iterable[str] = None, + expected_update: TSchemaTables = None, + ) -> Optional[TSchemaTables]: super().update_stored_schema(only_tables, expected_update) applied_update: TSchemaTables = {} @@ -255,13 +389,17 @@ def update_stored_schema(self, only_tables: Iterable[str] = None, expected_updat schema_info = None if schema_info is None: - logger.info(f"Schema with hash {self.schema.stored_version_hash} " - "not found in the storage. upgrading") + logger.info( + f"Schema with hash {self.schema.stored_version_hash} " + "not found in the storage. upgrading" + ) self._execute_schema_update(only_tables) else: - logger.info(f"Schema with hash {self.schema.stored_version_hash} " - f"inserted at {schema_info.inserted_at} found " - "in storage, no upgrade required") + logger.info( + f"Schema with hash {self.schema.stored_version_hash} " + f"inserted at {schema_info.inserted_at} found " + "in storage, no upgrade required" + ) return applied_update def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: @@ -279,11 +417,16 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] field: TArrowField for field in arrow_schema: name = self.schema.naming.normalize_identifier(field.name) - table_schema[name] = {"name": name, **self.type_mapper.from_db_type(field.type), } + table_schema[name] = { + "name": name, + **self.type_mapper.from_db_type(field.type), + } return True, table_schema @lancedb_error - def add_table_fields(self, table_name: str, field_schemas: List[TArrowField]) -> Optional[Table]: + def add_table_fields( + self, table_name: str, field_schemas: List[TArrowField] + ) -> Optional[Table]: """Add multiple fields to the LanceDB table at once. 
Args: @@ -316,16 +459,25 @@ def add_table_fields(self, table_name: str, field_schemas: List[TArrowField]) -> def _execute_schema_update(self, only_tables: Iterable[str]) -> None: for table_name in only_tables or self.schema.tables: exists, existing_columns = self.get_storage_table(table_name) - new_columns = self.schema.get_new_table_columns(table_name, existing_columns, self.capabilities.generates_case_sensitive_identifiers(), ) + new_columns = self.schema.get_new_table_columns( + table_name, + existing_columns, + self.capabilities.generates_case_sensitive_identifiers(), + ) logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") if len(new_columns) > 0: if exists: - field_schemas: List[TArrowField] = [make_arrow_field_schema(column["name"], column, self.type_mapper) for column in new_columns] + field_schemas: List[TArrowField] = [ + make_arrow_field_schema(column["name"], column, self.type_mapper) + for column in new_columns + ] fq_table_name = self.make_qualified_table_name(table_name) self.add_table_fields(fq_table_name, field_schemas) else: if table_name not in self.schema.dlt_table_names(): - embedding_fields = get_columns_names_with_prop(self.schema.get_table(table_name=table_name), VECTORIZE_HINT) + embedding_fields = get_columns_names_with_prop( + self.schema.get_table(table_name=table_name), VECTORIZE_HINT + ) vector_field_name = self.vector_field_name id_field_name = self.id_field_name embedding_model_func = self.model_func @@ -337,7 +489,16 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: embedding_model_func = None embedding_model_dimensions = None - table_schema: TArrowSchema = make_arrow_table_schema(table_name, schema=self.schema, type_mapper=self.type_mapper, embedding_fields=embedding_fields, embedding_model_func=embedding_model_func, embedding_model_dimensions=embedding_model_dimensions, vector_field_name=vector_field_name, id_field_name=id_field_name, ) + table_schema: TArrowSchema = make_arrow_table_schema( + table_name, + schema=self.schema, + type_mapper=self.type_mapper, + embedding_fields=embedding_fields, + embedding_model_func=embedding_model_func, + embedding_model_dimensions=embedding_model_dimensions, + vector_field_name=vector_field_name, + id_field_name=id_field_name, + ) fq_table_name = self.make_qualified_table_name(table_name) self.create_table(fq_table_name, table_schema) @@ -345,13 +506,33 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: @lancedb_error def update_schema_in_storage(self) -> None: - records = [{self.schema.naming.normalize_identifier("version"): self.schema.version, self.schema.naming.normalize_identifier("engine_version"): self.schema.ENGINE_VERSION, - self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), self.schema.naming.normalize_identifier("schema_name"): self.schema.name, - self.schema.naming.normalize_identifier("version_hash"): self.schema.stored_version_hash, self.schema.naming.normalize_identifier("schema"): json.dumps(self.schema.to_dict()), }] + records = [ + { + self.schema.naming.normalize_identifier("version"): self.schema.version, + self.schema.naming.normalize_identifier( + "engine_version" + ): self.schema.ENGINE_VERSION, + self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), + self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier( + "version_hash" + ): self.schema.stored_version_hash, + 
self.schema.naming.normalize_identifier("schema"): json.dumps( + self.schema.to_dict() + ), + } + ] fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) - write_disposition = self.schema.get_table(self.schema.version_table_name).get("write_disposition") + write_disposition = self.schema.get_table(self.schema.version_table_name).get( + "write_disposition" + ) - upload_batch(records, db_client=self.db_client, table_name=fq_version_table_name, write_disposition=write_disposition, ) + write_to_db( + records, + db_client=self.db_client, + table_name=fq_version_table_name, + write_disposition=write_disposition, + ) @lancedb_error def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: @@ -378,18 +559,31 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: # Read the tables into memory as Arrow tables, with pushdown predicates, so we pull as less # data into memory as possible. - state_table = (state_table_.search().where(f"`{p_pipeline_name}` = '{pipeline_name}'", prefilter=True).to_arrow()) + state_table = ( + state_table_.search() + .where(f"`{p_pipeline_name}` = '{pipeline_name}'", prefilter=True) + .to_arrow() + ) loads_table = loads_table_.search().where(f"`{p_status}` = 0", prefilter=True).to_arrow() # Join arrow tables in-memory. - joined_table: pa.Table = state_table.join(loads_table, keys=p_dlt_load_id, right_keys=p_load_id, join_type="inner").sort_by([(p_dlt_load_id, "descending")]) + joined_table: pa.Table = state_table.join( + loads_table, keys=p_dlt_load_id, right_keys=p_load_id, join_type="inner" + ).sort_by([(p_dlt_load_id, "descending")]) - if joined_table.num_rows==0: + if joined_table.num_rows == 0: return None state = joined_table.take([0]).to_pylist()[0] - return StateInfo(version=state[p_version], engine_version=state[p_engine_version], pipeline_name=state[p_pipeline_name], state=state[p_state], created_at=pendulum.instance( - state[p_created_at]), version_hash=state[p_version_hash], _dlt_load_id=state[p_dlt_load_id], ) + return StateInfo( + version=state[p_version], + engine_version=state[p_engine_version], + pipeline_name=state[p_pipeline_name], + state=state[p_state], + created_at=pendulum.instance(state[p_created_at]), + version_hash=state[p_version_hash], + _dlt_load_id=state[p_dlt_load_id], + ) @lancedb_error def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: @@ -405,11 +599,21 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI p_schema = self.schema.naming.normalize_identifier("schema") try: - schemas = (version_table.search().where(f'`{p_version_hash}` = "{schema_hash}"', prefilter=True)).to_list() + schemas = ( + version_table.search().where( + f'`{p_version_hash}` = "{schema_hash}"', prefilter=True + ) + ).to_list() most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] - return StorageSchemaInfo(version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], version=most_recent_schema[p_version], engine_version= - most_recent_schema[p_engine_version], inserted_at=most_recent_schema[p_inserted_at], schema=most_recent_schema[p_schema], ) + return StorageSchemaInfo( + version_hash=most_recent_schema[p_version_hash], + schema_name=most_recent_schema[p_schema_name], + version=most_recent_schema[p_version], + engine_version=most_recent_schema[p_engine_version], + inserted_at=most_recent_schema[p_inserted_at], + schema=most_recent_schema[p_schema], + ) except IndexError: 
return None @@ -428,15 +632,30 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: p_schema = self.schema.naming.normalize_identifier("schema") try: - schemas = (version_table.search().where(f'`{p_schema_name}` = "{self.schema.name}"', prefilter=True)).to_list() + schemas = ( + version_table.search().where( + f'`{p_schema_name}` = "{self.schema.name}"', prefilter=True + ) + ).to_list() most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] - return StorageSchemaInfo(version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], version=most_recent_schema[p_version], engine_version= - most_recent_schema[p_engine_version], inserted_at=most_recent_schema[p_inserted_at], schema=most_recent_schema[p_schema], ) + return StorageSchemaInfo( + version_hash=most_recent_schema[p_version_hash], + schema_name=most_recent_schema[p_schema_name], + version=most_recent_schema[p_version], + engine_version=most_recent_schema[p_engine_version], + inserted_at=most_recent_schema[p_inserted_at], + schema=most_recent_schema[p_schema], + ) except IndexError: return None - def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType, ) -> None: + def __exit__( + self, + exc_type: Type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ) -> None: pass def __enter__(self) -> "LanceDBClient": @@ -444,13 +663,27 @@ def __enter__(self) -> "LanceDBClient": @lancedb_error def complete_load(self, load_id: str) -> None: - records = [{self.schema.naming.normalize_identifier("load_id"): load_id, self.schema.naming.normalize_identifier("schema_name"): self.schema.name, - self.schema.naming.normalize_identifier("status"): 0, self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), - self.schema.naming.normalize_identifier("schema_version_hash"): None, # Payload schema must match the target schema. - }] + records = [ + { + self.schema.naming.normalize_identifier("load_id"): load_id, + self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier("status"): 0, + self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), + self.schema.naming.normalize_identifier( + "schema_version_hash" + ): None, # Payload schema must match the target schema. 
+ } + ] fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) - write_disposition = self.schema.get_table(self.schema.loads_table_name).get("write_disposition") - upload_batch(records, db_client=self.db_client, table_name=fq_loads_table_name, write_disposition=write_disposition, ) + write_disposition = self.schema.get_table(self.schema.loads_table_name).get( + "write_disposition" + ) + write_to_db( + records, + db_client=self.db_client, + table_name=fq_loads_table_name, + write_disposition=write_disposition, + ) def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") @@ -458,22 +691,48 @@ def restore_file_load(self, file_path: str) -> LoadJob: def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: parent_table = table.get("parent") - return LoadLanceDBJob(self.schema, table, file_path, type_mapper=self.type_mapper, db_client=self.db_client, client_config=self.config, model_func=self.model_func, fq_table_name=self.make_qualified_table_name( - table["name"]), fq_parent_table_name=(self.make_qualified_table_name(parent_table) if parent_table else None), ) - - def create_table_chain_completed_followup_jobs(self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[NewLoadJob]: + return LoadLanceDBJob( + self.schema, + table, + file_path, + type_mapper=self.type_mapper, + db_client=self.db_client, + client_config=self.config, + model_func=self.model_func, + fq_table_name=self.make_qualified_table_name(table["name"]), + fq_parent_table_name=( + self.make_qualified_table_name(parent_table) if parent_table else None + ), + ) + + def create_table_chain_completed_followup_jobs( + self, + table_chain: Sequence[TTableSchema], + completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, + ) -> List[NewLoadJob]: assert completed_table_chain_jobs is not None - jobs = super().create_table_chain_completed_followup_jobs(table_chain, completed_table_chain_jobs) + jobs = super().create_table_chain_completed_followup_jobs( + table_chain, completed_table_chain_jobs + ) for table in table_chain: if table in self.schema.dlt_tables(): continue # Only tables with merge disposition are dispatched for orphan removal jobs. 
- if table.get("write_disposition")=="merge": + if table.get("write_disposition") == "merge": parent_table = table.get("parent") - jobs.append(LanceDBRemoveOrphansJob(db_client=self.db_client, table_schema=self.prepare_load_table(table["name"]), fq_table_name=self.make_qualified_table_name( - table["name"]), fq_parent_table_name=(self.make_qualified_table_name(parent_table) if parent_table else None), client_config=self.config, )) + jobs.append( + LanceDBRemoveOrphansJob( + db_client=self.db_client, + table_schema=self.prepare_load_table(table["name"]), + fq_table_name=self.make_qualified_table_name(table["name"]), + fq_parent_table_name=( + self.make_qualified_table_name(parent_table) if parent_table else None + ), + client_config=self.config, + ) + ) return jobs @@ -484,8 +743,18 @@ def table_exists(self, table_name: str) -> bool: class LoadLanceDBJob(LoadJob, FollowupJob): arrow_schema: TArrowSchema - def __init__(self, schema: Schema, table_schema: TTableSchema, local_path: str, type_mapper: LanceDBTypeMapper, db_client: DBConnection, client_config: LanceDBClientConfiguration, model_func: TextEmbeddingFunction, fq_table_name: str, fq_parent_table_name: - Optional[str], ) -> None: + def __init__( + self, + schema: Schema, + table_schema: TTableSchema, + local_path: str, + type_mapper: LanceDBTypeMapper, + db_client: DBConnection, + client_config: LanceDBClientConfiguration, + model_func: TextEmbeddingFunction, + fq_table_name: str, + fq_parent_table_name: Optional[str], + ) -> None: file_name = FileStorage.get_file_name_from_file_path(local_path) super().__init__(file_name) self.schema: Schema = schema @@ -500,24 +769,28 @@ def __init__(self, schema: Schema, table_schema: TTableSchema, local_path: str, self.embedding_model_func: TextEmbeddingFunction = model_func self.embedding_model_dimensions: int = client_config.embedding_model_dimensions self.id_field_name: str = client_config.id_field_name - self.write_disposition: TWriteDisposition = cast(TWriteDisposition, self.table_schema.get("write_disposition", "append")) + self.write_disposition: TWriteDisposition = cast( + TWriteDisposition, self.table_schema.get("write_disposition", "append") + ) with FileStorage.open_zipsafe_ro(local_path, mode="rb") as f: - arrow_table: pa.Table = pq.read_table(f, memory_map=True) + arrow_table: pa.Table = pq.read_table(f) if self.table_schema not in self.schema.dlt_tables(): - for record in records: - # Add reserved ID fields. - uuid_id = (generate_uuid(record, self.unique_identifiers, self.fq_table_name) if self.unique_identifiers else str(uuid.uuid4())) - record.update({self.id_field_name: uuid_id}) - - # LanceDB expects all fields in the target arrow table to be present in the data payload. - # We add and set these missing fields, that are fields not present in the target schema, to NULL. 
- missing_fields = set(self.table_schema["columns"]) - set(record) - for field in missing_fields: - record[field] = None - - upload_batch(records, db_client=db_client, table_name=self.fq_table_name, write_disposition=self.write_disposition, id_field_name=self.id_field_name, ) + arrow_table = generate_arrow_uuid_column( + arrow_table, + unique_identifiers=self.unique_identifiers, + table_name=self.fq_table_name, + id_field_name=self.id_field_name, + ) + + write_to_db( + arrow_table, + db_client=db_client, + table_name=self.fq_table_name, + write_disposition=self.write_disposition, + id_field_name=self.id_field_name, + ) def state(self) -> TLoadJobState: return "completed" @@ -527,17 +800,34 @@ def exception(self) -> str: class LanceDBRemoveOrphansJob(NewLoadJobImpl): - def __init__(self, db_client: DBConnection, table_schema: TTableSchema, client_config: LanceDBClientConfiguration, fq_table_name: str, fq_parent_table_name: Optional[str], ) -> None: + def __init__( + self, + db_client: DBConnection, + table_schema: TTableSchema, + client_config: LanceDBClientConfiguration, + fq_table_name: str, + fq_parent_table_name: Optional[str], + ) -> None: self.db_client = db_client self.table_schema: TTableSchema = table_schema self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name - self.write_disposition: TWriteDisposition = cast(TWriteDisposition, self.table_schema.get("write_disposition")) + self.write_disposition: TWriteDisposition = cast( + TWriteDisposition, self.table_schema.get("write_disposition") + ) self.id_field_name: str = client_config.id_field_name - job_id = ParsedLoadJobFileName(table_schema["name"], ParsedLoadJobFileName.new_file_id(), 0, "parquet", ).file_name() + job_id = ParsedLoadJobFileName( + table_schema["name"], + ParsedLoadJobFileName.new_file_id(), + 0, + "parquet", + ).file_name() - super().__init__(file_name=job_id, status="running", ) + super().__init__( + file_name=job_id, + status="running", + ) self._save_text_file("") @@ -546,9 +836,11 @@ def __init__(self, db_client: DBConnection, table_schema: TTableSchema, client_c def execute(self) -> None: orphaned_ids: Set[str] - if self.write_disposition!="merge": - raise DestinationTerminalException(f"Unsupported write disposition {self.write_disposition} for LanceDB Destination" - " Orphan Removal Job - failed AND WILL **NOT** BE RETRIED.") + if self.write_disposition != "merge": + raise DestinationTerminalException( + f"Unsupported write disposition {self.write_disposition} for LanceDB Destination" + " Orphan Removal Job - failed AND WILL **NOT** BE RETRIED." + ) # Orphans are removed irrespective of which merge strategy is picked. try: @@ -558,7 +850,9 @@ def execute(self) -> None: parent_table = self.db_client.open_table(self.fq_parent_table_name) parent_table.checkout_latest() except FileNotFoundError as e: - raise DestinationTransientException("Couldn't open lancedb database. Orphan removal WILL BE RETRIED") from e + raise DestinationTransientException( + "Couldn't open lancedb database. 
Orphan removal WILL BE RETRIED" + ) from e try: if self.fq_parent_table_name: @@ -569,7 +863,7 @@ def execute(self) -> None: if orphaned_ids := child_ids - parent_ids: if len(orphaned_ids) > 1: child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids)==1: + elif len(orphaned_ids) == 1: child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") else: @@ -577,22 +871,31 @@ def execute(self) -> None: # If document ID is defined, we use this as the sole grouping key to identify stale chunks, # else fallback to the compound `id_field_name`. - grouping_key = (get_columns_names_with_prop(self.table_schema, DOCUMENT_ID_HINT) or self.id_field_name) + grouping_key = ( + get_columns_names_with_prop(self.table_schema, DOCUMENT_ID_HINT) + or self.id_field_name + ) grouping_key = grouping_key if isinstance(grouping_key, list) else [grouping_key] - child_table_arrow: pa.Table = child_table.to_arrow().select([*grouping_key, "_dlt_load_id", "_dlt_id"]) + child_table_arrow: pa.Table = child_table.to_arrow().select( + [*grouping_key, "_dlt_load_id", "_dlt_id"] + ) - grouped = child_table_arrow.group_by(grouping_key).aggregate([("_dlt_load_id", "max")]) + grouped = child_table_arrow.group_by(grouping_key).aggregate( + [("_dlt_load_id", "max")] + ) joined = child_table_arrow.join(grouped, keys=grouping_key) orphaned_mask = pc.not_equal(joined["_dlt_load_id"], joined["_dlt_load_id_max"]) orphaned_ids = joined.filter(orphaned_mask).column("_dlt_id").to_pylist() if len(orphaned_ids) > 1: child_table.delete(f"_dlt_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids)==1: + elif len(orphaned_ids) == 1: child_table.delete(f"_dlt_id = '{orphaned_ids.pop()}'") except ArrowInvalid as e: - raise DestinationTerminalException("Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED.") from e + raise DestinationTerminalException( + "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." + ) from e def state(self) -> TLoadJobState: return "completed" diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index f202903598..0edb8487c5 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -2,9 +2,11 @@ import uuid from typing import Sequence, Union, Dict, List +import pyarrow as pa +import pyarrow.compute as pc + from dlt.common.schema import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop -from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider @@ -16,19 +18,35 @@ } -def generate_uuid(data: DictStrAny, unique_identifiers: Sequence[str], table_name: str) -> str: - """Generates deterministic UUID - used for deduplication. +# TODO: Update `generate_arrow_uuid_column` when pyarrow 17.0.0 becomes available with vectorized operations (batched + memory-mapped) +def generate_arrow_uuid_column( + table: pa.Table, unique_identifiers: List[str], id_field_name: str, table_name: str +) -> pa.Table: + """Generates deterministic UUID - used for deduplication, returning a new arrow + table with added UUID column. Args: - data (Dict[str, Any]): Arbitrary data to generate UUID for. - unique_identifiers (Sequence[str]): A list of unique identifiers. - table_name (str): LanceDB table name. + table (pa.Table): PyArrow table to generate UUIDs for. + unique_identifiers (List[str]): A list of unique identifier column names. + id_field_name (str): Name of the new UUID column. 
+ table_name (str): Name of the table. Returns: - str: A string representation of the generated UUID. + pa.Table: New PyArrow table with the new UUID column. """ - data_id = "_".join(str(data[key]) for key in unique_identifiers) - return str(uuid.uuid5(uuid.NAMESPACE_DNS, table_name + data_id)) + + string_columns = [] + for col in unique_identifiers: + column = table[col] + column = pc.cast(column, pa.string()) + column = pc.fill_null(column, "") + string_columns.append(column.to_pylist()) + + concat_values = ["".join(x) for x in zip(*string_columns)] + uuids = [str(uuid.uuid5(uuid.NAMESPACE_OID, x + table_name)) for x in concat_values] + uuid_column = pa.array(uuids) + + return table.append_column(id_field_name, uuid_column) def list_merge_identifiers(table_schema: TTableSchema) -> List[str]: From 0a9682f972f66bb7e45515d89e6765c001acf181 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 5 Aug 2024 17:41:58 +0200 Subject: [PATCH 033/113] Refactor unique identifier handling for LanceDB tables Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 6 +++--- dlt/destinations/impl/lancedb/utils.py | 21 ++++++++++--------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 1294155c00..e337936541 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -71,7 +71,7 @@ TArrowField, ) from dlt.destinations.impl.lancedb.utils import ( - list_merge_identifiers, + get_unique_identifiers_from_table_schema, set_non_standard_providers_environment_variables, generate_arrow_uuid_column, ) @@ -764,7 +764,7 @@ def __init__( self.table_name: str = table_schema["name"] self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name - self.unique_identifiers: List[str] = list_merge_identifiers(table_schema) + self.unique_identifiers: List[str] = get_unique_identifiers_from_table_schema(table_schema) self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) self.embedding_model_func: TextEmbeddingFunction = model_func self.embedding_model_dimensions: int = client_config.embedding_model_dimensions @@ -776,7 +776,7 @@ def __init__( with FileStorage.open_zipsafe_ro(local_path, mode="rb") as f: arrow_table: pa.Table = pq.read_table(f) - if self.table_schema not in self.schema.dlt_tables(): + if self.table_schema['name'] not in self.schema.dlt_table_names(): arrow_table = generate_arrow_uuid_column( arrow_table, unique_identifiers=self.unique_identifiers, diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 0edb8487c5..9e9d05c936 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -35,21 +35,20 @@ def generate_arrow_uuid_column( pa.Table: New PyArrow table with the new UUID column. 
""" - string_columns = [] + unique_identifiers_columns = [] for col in unique_identifiers: column = table[col] column = pc.cast(column, pa.string()) column = pc.fill_null(column, "") - string_columns.append(column.to_pylist()) + unique_identifiers_columns.append(column.to_pylist()) - concat_values = ["".join(x) for x in zip(*string_columns)] - uuids = [str(uuid.uuid5(uuid.NAMESPACE_OID, x + table_name)) for x in concat_values] + concatenated_ids = ["".join(x) for x in zip(*unique_identifiers_columns)] + uuids = [str(uuid.uuid5(uuid.NAMESPACE_OID, x + table_name)) for x in concatenated_ids] uuid_column = pa.array(uuids) - return table.append_column(id_field_name, uuid_column) -def list_merge_identifiers(table_schema: TTableSchema) -> List[str]: +def get_unique_identifiers_from_table_schema(table_schema: TTableSchema) -> List[str]: """Returns a list of merge keys for a table used for either merging or deduplication. Args: @@ -58,12 +57,14 @@ def list_merge_identifiers(table_schema: TTableSchema) -> List[str]: Returns: Sequence[str]: A list of unique column identifiers. """ + primary_keys = get_columns_names_with_prop(table_schema, "primary_key") + merge_keys = [] if table_schema.get("write_disposition") == "merge": - primary_keys = get_columns_names_with_prop(table_schema, "primary_key") merge_keys = get_columns_names_with_prop(table_schema, "merge_key") - if join_keys := list(set(primary_keys + merge_keys)): - return join_keys - return get_columns_names_with_prop(table_schema, "unique") + if join_keys := list(set(primary_keys + merge_keys)): + return join_keys + else: + return get_columns_names_with_prop(table_schema, "unique") def set_non_standard_providers_environment_variables( From a99224a6a3b6315ae143aa68cccea3d47a97c9be Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 5 Aug 2024 18:10:03 +0200 Subject: [PATCH 034/113] Optimize UUID column generation for LanceDB tables Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/utils.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 9e9d05c936..9091d2db1e 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -37,15 +37,17 @@ def generate_arrow_uuid_column( unique_identifiers_columns = [] for col in unique_identifiers: - column = table[col] - column = pc.cast(column, pa.string()) - column = pc.fill_null(column, "") + column = pc.fill_null(pc.cast(table[col], pa.string()), "") unique_identifiers_columns.append(column.to_pylist()) - concatenated_ids = ["".join(x) for x in zip(*unique_identifiers_columns)] - uuids = [str(uuid.uuid5(uuid.NAMESPACE_OID, x + table_name)) for x in concatenated_ids] - uuid_column = pa.array(uuids) - return table.append_column(id_field_name, uuid_column) + uuids = pa.array( + [ + str(uuid.uuid5(uuid.NAMESPACE_OID, x + table_name)) + for x in ["".join(x) for x in zip(*unique_identifiers_columns)] + ] + ) + + return table.append_column(id_field_name, uuids) def get_unique_identifiers_from_table_schema(table_schema: TTableSchema) -> List[str]: From 895331b81f73f4dc7c5112390a97af866734de3c Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 5 Aug 2024 20:04:26 +0200 Subject: [PATCH 035/113] Refactor LanceDBClient to use string type hints for Table Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git 
a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index e337936541..0ad8c9cdfb 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -21,9 +21,9 @@ from lancedb import DBConnection from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore -from lancedb.table import Table # type: ignore +import lancedb.table # type: ignore from numpy import ndarray -from pyarrow import Array, ChunkedArray, ArrowInvalid, Table +from pyarrow import Array, ChunkedArray, ArrowInvalid from dlt.common import json, pendulum, logger from dlt.common.destination import DestinationCapabilitiesContext @@ -271,7 +271,7 @@ def make_qualified_table_name(self, table_name: str) -> str: ) def get_table_schema(self, table_name: str) -> TArrowSchema: - schema_table: Table = self.db_client.open_table(table_name) + schema_table: "lancedb.table.Table" = self.db_client.open_table(table_name) schema_table.checkout_latest() schema = schema_table.schema return cast( @@ -280,7 +280,9 @@ def get_table_schema(self, table_name: str) -> TArrowSchema: ) @lancedb_error - def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table: + def create_table( + self, table_name: str, schema: TArrowSchema, mode: str = "create" + ) -> "lancedb.table.Table": """Create a LanceDB Table from the provided LanceModel or PyArrow schema. Args: @@ -314,7 +316,7 @@ def query_table( Returns: A LanceDB query builder. """ - query_table: Table = self.db_client.open_table(table_name) + query_table: "lancedb.table.Table" = self.db_client.open_table(table_name) query_table.checkout_latest() return query_table.search(query=query) @@ -366,7 +368,7 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: def is_storage_initialized(self) -> bool: return self.table_exists(self.sentinel_table) - def _create_sentinel_table(self) -> Table: + def _create_sentinel_table(self) -> "lancedb.table.Table": """Create an empty table to indicate that the storage is initialized.""" return self.create_table(schema=NULL_SCHEMA, table_name=self.sentinel_table) @@ -408,7 +410,7 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] try: fq_table_name = self.make_qualified_table_name(table_name) - table: Table = self.db_client.open_table(fq_table_name) + table: "lancedb.table.Table" = self.db_client.open_table(fq_table_name) table.checkout_latest() arrow_schema: TArrowSchema = table.schema except FileNotFoundError: @@ -425,15 +427,15 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] @lancedb_error def add_table_fields( - self, table_name: str, field_schemas: List[TArrowField] - ) -> Optional[Table]: + self, table_name: str, field_schemas: List[pa.Field] + ) -> Optional["lancedb.table.Table"]: """Add multiple fields to the LanceDB table at once. Args: - table_name: The name of the table to create the fields on. - field_schemas: The list of fields to create. + table_name: The name of the table to create the fields on. + field_schemas: The list of PyArrow Fields to create. 
""" - table: Table = self.db_client.open_table(table_name) + table: "lancedb.table.Table" = self.db_client.open_table(table_name) table.checkout_latest() arrow_table = table.to_arrow() @@ -540,10 +542,10 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name) fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) - state_table_: Table = self.db_client.open_table(fq_state_table_name) + state_table_: "lancedb.table.Table" = self.db_client.open_table(fq_state_table_name) state_table_.checkout_latest() - loads_table_: Table = self.db_client.open_table(fq_loads_table_name) + loads_table_: "lancedb.table.Table" = self.db_client.open_table(fq_loads_table_name) loads_table_.checkout_latest() # normalize property names @@ -589,7 +591,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) - version_table: Table = self.db_client.open_table(fq_version_table_name) + version_table: "lancedb.table.Table" = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() p_version_hash = self.schema.naming.normalize_identifier("version_hash") p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") @@ -622,7 +624,7 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage.""" fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) - version_table: Table = self.db_client.open_table(fq_version_table_name) + version_table: "lancedb.table.Table" = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() p_version_hash = self.schema.naming.normalize_identifier("version_hash") p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") @@ -776,7 +778,7 @@ def __init__( with FileStorage.open_zipsafe_ro(local_path, mode="rb") as f: arrow_table: pa.Table = pq.read_table(f) - if self.table_schema['name'] not in self.schema.dlt_table_names(): + if self.table_schema["name"] not in self.schema.dlt_table_names(): arrow_table = generate_arrow_uuid_column( arrow_table, unique_identifiers=self.unique_identifiers, From a881e7a0571f4764a2212be251ce752d205dcd5d Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 5 Aug 2024 20:07:59 +0200 Subject: [PATCH 036/113] Minor refactor Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 0ad8c9cdfb..278e01d0d1 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -48,7 +48,7 @@ TColumnType, TTableFormat, TTableSchemaColumns, - TWriteDisposition, + TWriteDisposition, TColumnSchema, ) from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName @@ -461,13 +461,13 @@ def add_table_fields( def _execute_schema_update(self, only_tables: Iterable[str]) -> None: for table_name in only_tables or self.schema.tables: exists, existing_columns = self.get_storage_table(table_name) - new_columns = self.schema.get_new_table_columns( + new_columns: List[TColumnSchema] = 
self.schema.get_new_table_columns( table_name, existing_columns, self.capabilities.generates_case_sensitive_identifiers(), ) logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") - if len(new_columns) > 0: + if new_columns: if exists: field_schemas: List[TArrowField] = [ make_arrow_field_schema(column["name"], column, self.type_mapper) From 7f245e2534c03e23e08db359be45125880bb881a Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 5 Aug 2024 23:21:21 +0200 Subject: [PATCH 037/113] Implement efficient schema update with Nullability support Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 43 ++++++++++--------- dlt/destinations/impl/lancedb/schema.py | 21 ++++++++- dlt/destinations/impl/lancedb/utils.py | 2 + 3 files changed, 44 insertions(+), 22 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 278e01d0d1..9554d12a94 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -15,13 +15,13 @@ ) import lancedb # type: ignore +import lancedb.table # type: ignore import pyarrow as pa import pyarrow.compute as pc import pyarrow.parquet as pq from lancedb import DBConnection from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore -import lancedb.table # type: ignore from numpy import ndarray from pyarrow import Array, ChunkedArray, ArrowInvalid @@ -48,7 +48,8 @@ TColumnType, TTableFormat, TTableSchemaColumns, - TWriteDisposition, TColumnSchema, + TWriteDisposition, + TColumnSchema, ) from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName @@ -69,6 +70,7 @@ TArrowSchema, NULL_SCHEMA, TArrowField, + arrow_datatype_to_fusion_datatype, ) from dlt.destinations.impl.lancedb.utils import ( get_unique_identifiers_from_table_schema, @@ -426,34 +428,33 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] return True, table_schema @lancedb_error - def add_table_fields( - self, table_name: str, field_schemas: List[pa.Field] - ) -> Optional["lancedb.table.Table"]: - """Add multiple fields to the LanceDB table at once. + def extend_lancedb_table_schema(self, table_name: str, field_schemas: List[pa.Field]) -> None: + """Extend LanceDB table schema with empty columns. Args: table_name: The name of the table to create the fields on. - field_schemas: The list of PyArrow Fields to create. + field_schemas: The list of PyArrow Fields to create in the target LanceDB table. """ table: "lancedb.table.Table" = self.db_client.open_table(table_name) table.checkout_latest() - arrow_table = table.to_arrow() - # Check if any of the new fields already exist in the table. - existing_fields = set(arrow_table.schema.names) - new_fields = [field for field in field_schemas if field.name not in existing_fields] - - if not new_fields: - # All fields already present, skip. - return None + try: + # Use DataFusion SQL syntax to alter fields without loading data into client memory. + # Currently, the most efficient way to modify column values is in LanceDB. 
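
The `new_fields` mapping of CAST expressions built just below follows the pattern this comment describes. As a minimal standalone sketch, assuming a throwaway local database and made-up table and column names, the same two calls look like this:

import lancedb

db = lancedb.connect("/tmp/lancedb_alter_demo")  # made-up local path
tbl = db.create_table("demo", data=[{"id": 1}, {"id": 2}])

# Add a NULL-filled string column without pulling any rows into client memory.
tbl.add_columns({"note": "CAST(NULL AS STRING)"})

# The SQL path does not mark the new column nullable, so flip it in the Arrow schema.
tbl.alter_columns({"path": "note", "nullable": True})

print(tbl.schema.field("note"))  # note: string, nullable
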
+ new_fields = { + field.name: f"CAST(NULL AS {arrow_datatype_to_fusion_datatype(field.type)})" + for field in field_schemas + } + table.add_columns(new_fields) - null_arrays = [pa.nulls(len(arrow_table), type=field.type) for field in new_fields] + # Make new columns nullable in the Arrow schema. + # Necessary because the Datafusion SQL API doesn't set new columns as nullable by default. + for field in field_schemas: + table.alter_columns({"path": field.name, "nullable": field.nullable}) - for field, null_array in zip(new_fields, null_arrays): - arrow_table = arrow_table.append_column(field, null_array) + # TODO: Update method below doesn't work for bulk NULL assignments, raise with LanceDB developers. + # table.update(values={field.name: None}) - try: - return self.db_client.create_table(table_name, arrow_table, mode="overwrite") except OSError: # Error occurred while creating the table, skip. return None @@ -474,7 +475,7 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: for column in new_columns ] fq_table_name = self.make_qualified_table_name(table_name) - self.add_table_fields(fq_table_name, field_schemas) + self.extend_lancedb_table_schema(fq_table_name, field_schemas) else: if table_name not in self.schema.dlt_table_names(): embedding_fields = get_columns_names_with_prop( diff --git a/dlt/destinations/impl/lancedb/schema.py b/dlt/destinations/impl/lancedb/schema.py index c7cceec274..db624aeb12 100644 --- a/dlt/destinations/impl/lancedb/schema.py +++ b/dlt/destinations/impl/lancedb/schema.py @@ -1,6 +1,5 @@ """Utilities for creating arrow schemas from table schemas.""" -from dlt.common.json import json from typing import ( List, cast, @@ -11,6 +10,7 @@ from lancedb.embeddings import TextEmbeddingFunction # type: ignore from typing_extensions import TypeAlias +from dlt.common.json import json from dlt.common.schema import Schema, TColumnSchema from dlt.common.typing import DictStrAny from dlt.destinations.type_mapping import TypeMapper @@ -82,3 +82,22 @@ def make_arrow_table_schema( metadata["embedding_functions"] = json.dumps(embedding_functions).encode("utf-8") return pa.schema(arrow_schema, metadata=metadata) + + +def arrow_datatype_to_fusion_datatype(arrow_type: TArrowSchema) -> str: + type_map = { + pa.bool_(): "BOOLEAN", + pa.int64(): "BIGINT", + pa.float64(): "DOUBLE", + pa.utf8(): "STRING", + pa.binary(): "BYTEA", + pa.date32(): "DATE", + } + + if isinstance(arrow_type, pa.Decimal128Type): + return f"DECIMAL({arrow_type.precision}, {arrow_type.scale})" + + if isinstance(arrow_type, pa.TimestampType): + return "TIMESTAMP" + + return type_map.get(arrow_type, "UNKNOWN") diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 9091d2db1e..0e6d4744a7 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -74,3 +74,5 @@ def set_non_standard_providers_environment_variables( ) -> None: if embedding_model_provider in PROVIDER_ENVIRONMENT_VARIABLES_MAP: os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or "" + + From 4fc73dd04d520592b08e24417a3710305b1fc2bf Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 6 Aug 2024 00:31:14 +0200 Subject: [PATCH 038/113] Optimize orphaned chunks removal for large datasets Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py 
b/dlt/destinations/impl/lancedb/lancedb_client.py index 9554d12a94..a068842814 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -87,6 +87,7 @@ TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} +BATCH_PROCESS_CHUNK_SIZE = 10_000 class LanceDBTypeMapper(TypeMapper): @@ -859,15 +860,25 @@ def execute(self) -> None: try: if self.fq_parent_table_name: - # Chunks and embeddings in child table. - parent_ids = set(pc.unique(parent_table.to_arrow()["_dlt_id"]).to_pylist()) - child_ids = set(pc.unique(child_table.to_arrow()["_dlt_parent_id"]).to_pylist()) - - if orphaned_ids := child_ids - parent_ids: - if len(orphaned_ids) > 1: - child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids) == 1: - child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") + parent_id_iter: "pa.RecordBatchReader" = ( + parent_table.to_lance().scanner(columns=["_dlt_id"]).to_reader() + ) + all_parent_ids = set() + + for batch in parent_id_iter: + chunk_ids = set(pc.unique(batch["_dlt_id"]).to_pylist()) + all_parent_ids.update(chunk_ids) + + # Delete it from db and clear memory. + if len(all_parent_ids) >= BATCH_PROCESS_CHUNK_SIZE: + delete_condition = f"_dlt_parent_id NOT IN {tuple(all_parent_ids)}" + child_table.delete(delete_condition) + all_parent_ids.clear() + + # Process any remaining IDs. + if all_parent_ids: + delete_condition = f"_dlt_parent_id NOT IN {tuple(all_parent_ids)}" + child_table.delete(delete_condition) else: # Chunks and embeddings in the root table. From 9378f505946c064c200719ad36dda4019a8e156b Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 6 Aug 2024 20:05:50 +0200 Subject: [PATCH 039/113] Projection pushdown Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 39 ++++++++++--------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index a068842814..bc57842989 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -860,25 +860,26 @@ def execute(self) -> None: try: if self.fq_parent_table_name: - parent_id_iter: "pa.RecordBatchReader" = ( - parent_table.to_lance().scanner(columns=["_dlt_id"]).to_reader() + # Chunks and embeddings in child table. + # By referencing underlying lance dataset we benefit from projection push-down to storage layer (LanceDB). + parent_ids = set( + pc.unique( + parent_table.to_lance().to_table(columns=["_dlt_id"])["_dlt_id"] + ).to_pylist() + ) + child_ids = set( + pc.unique( + child_table.to_lance().to_table(columns=["_dlt_parent_id"])[ + "_dlt_parent_id" + ] + ).to_pylist() ) - all_parent_ids = set() - - for batch in parent_id_iter: - chunk_ids = set(pc.unique(batch["_dlt_id"]).to_pylist()) - all_parent_ids.update(chunk_ids) - - # Delete it from db and clear memory. - if len(all_parent_ids) >= BATCH_PROCESS_CHUNK_SIZE: - delete_condition = f"_dlt_parent_id NOT IN {tuple(all_parent_ids)}" - child_table.delete(delete_condition) - all_parent_ids.clear() - # Process any remaining IDs. 
- if all_parent_ids: - delete_condition = f"_dlt_parent_id NOT IN {tuple(all_parent_ids)}" - child_table.delete(delete_condition) + if orphaned_ids := child_ids - parent_ids: + if len(orphaned_ids) > 1: + child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") + elif len(orphaned_ids) == 1: + child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") else: # Chunks and embeddings in the root table. @@ -890,8 +891,8 @@ def execute(self) -> None: or self.id_field_name ) grouping_key = grouping_key if isinstance(grouping_key, list) else [grouping_key] - child_table_arrow: pa.Table = child_table.to_arrow().select( - [*grouping_key, "_dlt_load_id", "_dlt_id"] + child_table_arrow: pa.Table = child_table.to_lance().to_table( + columns=[*grouping_key, "_dlt_load_id", "_dlt_id"] ) grouped = child_table_arrow.group_by(grouping_key).aggregate( From 9b14583ad242a39d744ebdb26cbffe5c691d8ff1 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 6 Aug 2024 21:20:22 +0200 Subject: [PATCH 040/113] Format Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 0e6d4744a7..9091d2db1e 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -74,5 +74,3 @@ def set_non_standard_providers_environment_variables( ) -> None: if embedding_model_provider in PROVIDER_ENVIRONMENT_VARIABLES_MAP: os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or "" - - From e21f61bbb7ef804b36e12461c23df6a986b74219 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 6 Aug 2024 22:36:28 +0200 Subject: [PATCH 041/113] Prevent primary key and document ID hint conflict in merge disposition Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 15 ++++-- tests/load/lancedb/test_pipeline.py | 53 +++++++++++++++++++ 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index bc57842989..8d4c031947 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -42,6 +42,7 @@ NewLoadJob, FollowupJob, ) +from dlt.common.exceptions import SystemConfigurationException from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.schema.typing import ( @@ -884,12 +885,18 @@ def execute(self) -> None: else: # Chunks and embeddings in the root table. + document_id_field = get_columns_names_with_prop(self.table_schema, DOCUMENT_ID_HINT) + if document_id_field and get_columns_names_with_prop( + self.table_schema, "primary_key" + ): + raise SystemConfigurationException( + "You CANNOT specify a primary key AND a document ID hint for the same" + " resource when using merge disposition." + ) + # If document ID is defined, we use this as the sole grouping key to identify stale chunks, # else fallback to the compound `id_field_name`. 
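
The grouping described above feeds the stale-chunk detection that appears further down in this hunk (group by the document key, keep only rows from the newest `_dlt_load_id`, delete the rest). A minimal sketch of that idea on a made-up in-memory Arrow table, assuming, as the real code does, that the lexically largest `_dlt_load_id` is the newest load:

import pyarrow as pa
import pyarrow.compute as pc

# Three chunks for doc 1 across two loads, one chunk for doc 2 (illustrative data only).
chunks = pa.table(
    {
        "doc_id": [1, 1, 1, 2],
        "_dlt_load_id": ["load_1", "load_2", "load_2", "load_1"],
        "_dlt_id": ["chunk_a", "chunk_b", "chunk_c", "chunk_d"],
    }
)

# Newest load id per document, joined back onto every chunk row.
latest = chunks.group_by(["doc_id"]).aggregate([("_dlt_load_id", "max")])
joined = chunks.join(latest, keys=["doc_id"])

# Chunks older than the newest load for their document are stale and can be deleted.
stale_mask = pc.not_equal(joined["_dlt_load_id"], joined["_dlt_load_id_max"])
stale_ids = joined.filter(stale_mask).column("_dlt_id").to_pylist()
print(stale_ids)  # ['chunk_a'], superseded by doc 1's newer load
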
- grouping_key = ( - get_columns_names_with_prop(self.table_schema, DOCUMENT_ID_HINT) - or self.id_field_name - ) + grouping_key = document_id_field or self.id_field_name grouping_key = grouping_key if isinstance(grouping_key, list) else [grouping_key] child_table_arrow: pa.Table = child_table.to_lance().to_table( columns=[*grouping_key, "_dlt_load_id", "_dlt_id"] diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index ea52e2f472..4b964604e6 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -16,6 +16,7 @@ ) from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient from dlt.extract import DltResource +from dlt.pipeline.exceptions import PipelineStepFailed from tests.load.lancedb.utils import assert_table, chunk_document, mock_embed from tests.load.utils import sequence_generator, drop_active_pipeline_data from tests.pipeline.utils import assert_load_info @@ -646,3 +647,55 @@ def documents_source( tbl: Table = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] df = tbl.to_pandas() assert set(df["chunk_text"]) == expected_text + + +def test_primary_key_not_compatible_with_doc_id_hint_on_merge_disposition() -> None: + @dlt.resource( # type: ignore + write_disposition="merge", + table_name="document", + primary_key="doc_id", + columns={"doc_id": {DOCUMENT_ID_HINT: True}}, + ) + def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: + for doc in docs: + doc_id = doc["doc_id"] + chunks = chunk_document(doc["text"]) + embeddings = [ + { + "chunk_hash": digest128(chunk), + "chunk_text": chunk, + "embedding": mock_embed(), + } + for chunk in chunks + ] + yield {"doc_id": doc_id, "doc_text": doc["text"], "embeddings": embeddings} + + @dlt.source(max_table_nesting=1) + def documents_source( + docs: List[DictStrAny], + ) -> Union[Generator[Dict[str, Any], None, None], DltResource]: + return documents(docs) + + pipeline = dlt.pipeline( + pipeline_name="test_mandatory_doc_id_hint_on_merge_disposition", + destination="lancedb", + dataset_name="test_mandatory_doc_id_hint_on_merge_disposition", + dev_mode=True, + ) + + initial_docs = [ + { + "text": ( + "This is the first document. It contains some text that will be chunked and" + " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)" + ), + "doc_id": 1, + }, + { + "text": "Here's another document. 
It's a bit different from the first one.", + "doc_id": 2, + }, + ] + + with pytest.raises(PipelineStepFailed): + pipeline.run(documents(initial_docs)) From 9725d0e9fe0c47ca1120c364d8c02f70a1042223 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Wed, 7 Aug 2024 15:57:38 +0200 Subject: [PATCH 042/113] Add recommended file size for LanceDB destination Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/factory.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dlt/destinations/impl/lancedb/factory.py b/dlt/destinations/impl/lancedb/factory.py index beab6a4031..9acb82344b 100644 --- a/dlt/destinations/impl/lancedb/factory.py +++ b/dlt/destinations/impl/lancedb/factory.py @@ -30,6 +30,8 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.decimal_precision = (38, 18) caps.timestamp_precision = 6 + caps.recommended_file_size = 128_000_000 + return caps @property From 5238c1125afc234cd08a734b2454c8d0de4f795c Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Wed, 7 Aug 2024 16:04:07 +0200 Subject: [PATCH 043/113] Improve comment clarity for projection push-down in LanceDB Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 8d4c031947..951c8c725d 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -862,7 +862,8 @@ def execute(self) -> None: try: if self.fq_parent_table_name: # Chunks and embeddings in child table. - # By referencing underlying lance dataset we benefit from projection push-down to storage layer (LanceDB). + + # By referencing the underlying lance dataset, we benefit from projection push-down to the storage layer (LanceDB). 
parent_ids = set( pc.unique( parent_table.to_lance().to_table(columns=["_dlt_id"])["_dlt_id"] From c8f74680db23deec7219c88223414c7d4af15a48 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 8 Aug 2024 01:09:41 +0200 Subject: [PATCH 044/113] Update to new load interface Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 45 +++++++++---------- dlt/destinations/impl/lancedb/utils.py | 2 +- 2 files changed, 21 insertions(+), 26 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 86f4ce5854..7c69f5db8a 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -38,8 +38,6 @@ RunnableLoadJob, StorageSchemaInfo, StateInfo, - TLoadJobState, - NewLoadJob, FollowupJob, LoadJob, ) @@ -79,8 +77,7 @@ set_non_standard_providers_environment_variables, generate_arrow_uuid_column, ) -from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs -from dlt.destinations.job_impl import EmptyLoadJob, NewLoadJobImpl +from dlt.destinations.job_impl import FollowupJobImpl from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -153,7 +150,7 @@ def from_db_type( ) if isinstance(db_type, pa.Decimal128Type): precision, scale = db_type.precision, db_type.scale - if (precision, scale)==self.capabilities.wei_precision: + if (precision, scale) == self.capabilities.wei_precision: return cast(TColumnType, dict(data_type="wei")) return dict(data_type="decimal", precision=precision, scale=scale) return super().from_db_type(cast(str, db_type), precision, scale) @@ -193,9 +190,9 @@ def write_to_db( try: if write_disposition in ("append", "skip"): tbl.add(records) - elif write_disposition=="replace": + elif write_disposition == "replace": tbl.add(records, mode="overwrite") - elif write_disposition=="merge": + elif write_disposition == "merge": if not id_field_name: raise ValueError("To perform a merge update, 'id_field_name' must be specified.") tbl.merge_insert( @@ -244,7 +241,7 @@ def __init__( self.config.credentials.embedding_model_provider_api_key, ) # Use the monkey-patched implementation if openai was chosen. 
- if embedding_model_provider=="openai": + if embedding_model_provider == "openai": from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings self.model_func = PatchedOpenAIEmbeddings( @@ -339,7 +336,7 @@ def _get_table_names(self) -> List[str]: else: table_names = self.db_client.table_names() - return [table_name for table_name in table_names if table_name!=self.sentinel_table] + return [table_name for table_name in table_names if table_name != self.sentinel_table] @lancedb_error def drop_storage(self) -> None: @@ -578,7 +575,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: loads_table, keys=p_dlt_load_id, right_keys=p_load_id, join_type="inner" ).sort_by([(p_dlt_load_id, "descending")]) - if joined_table.num_rows==0: + if joined_table.num_rows == 0: return None state = joined_table.take([0]).to_pylist()[0] @@ -711,7 +708,7 @@ def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[NewLoadJob]: + ) -> List[FollowupJob]: assert completed_table_chain_jobs is not None jobs = super().create_table_chain_completed_followup_jobs( table_chain, completed_table_chain_jobs @@ -722,7 +719,7 @@ def create_table_chain_completed_followup_jobs( continue # Only tables with merge disposition are dispatched for orphan removal jobs. - if table.get("write_disposition")=="merge": + if table.get("write_disposition") == "merge": parent_table = table.get("parent") jobs.append( LanceDBRemoveOrphansJob( @@ -742,7 +739,7 @@ def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() -class LanceDBLoadJob(RunnableLoadJob, FollowupJob): +class LanceDBLoadJob(RunnableLoadJob): arrow_schema: TArrowSchema def __init__( @@ -765,7 +762,9 @@ def run(self) -> None: self._embedding_model_dimensions: int = self._job_client.config.embedding_model_dimensions self._id_field_name: str = self._job_client.config.id_field_name - unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema(self._load_table) + unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema( + self._load_table + ) write_disposition: TWriteDisposition = cast( TWriteDisposition, self._load_table.get("write_disposition", "append") ) @@ -776,9 +775,9 @@ def run(self) -> None: if self._load_table not in self._schema.dlt_tables(): arrow_table = generate_arrow_uuid_column( arrow_table, - unique_identifiers=self.unique_identifiers, - table_name=self.fq_table_name, - id_field_name=self.id_field_name, + unique_identifiers=unique_identifiers, + table_name=self._fq_table_name, + id_field_name=self._id_field_name, ) write_to_db( @@ -790,7 +789,7 @@ def run(self) -> None: ) -class LanceDBRemoveOrphansJob(NewLoadJobImpl): +class LanceDBRemoveOrphansJob(FollowupJobImpl): def __init__( self, db_client: DBConnection, @@ -817,7 +816,6 @@ def __init__( super().__init__( file_name=job_id, - status="running", ) self._save_text_file("") @@ -827,7 +825,7 @@ def __init__( def execute(self) -> None: orphaned_ids: Set[str] - if self.write_disposition!="merge": + if self.write_disposition != "merge": raise DestinationTerminalException( f"Unsupported write disposition {self.write_disposition} for LanceDB Destination" " Orphan Removal Job - failed AND WILL **NOT** BE RETRIED." 
@@ -866,7 +864,7 @@ def execute(self) -> None: if orphaned_ids := child_ids - parent_ids: if len(orphaned_ids) > 1: child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids)==1: + elif len(orphaned_ids) == 1: child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") else: @@ -898,13 +896,10 @@ def execute(self) -> None: if len(orphaned_ids) > 1: child_table.delete(f"_dlt_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids)==1: + elif len(orphaned_ids) == 1: child_table.delete(f"_dlt_id = '{orphaned_ids.pop()}'") except ArrowInvalid as e: raise DestinationTerminalException( "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." ) from e - - def state(self) -> TLoadJobState: - return "completed" diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 9091d2db1e..466111dc98 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -20,7 +20,7 @@ # TODO: Update `generate_arrow_uuid_column` when pyarrow 17.0.0 becomes available with vectorized operations (batched + memory-mapped) def generate_arrow_uuid_column( - table: pa.Table, unique_identifiers: List[str], id_field_name: str, table_name: str + table: pa.Table, unique_identifiers: Sequence[str], id_field_name: str, table_name: str ) -> pa.Table: """Generates deterministic UUID - used for deduplication, returning a new arrow table with added UUID column. From af561918ccf01ac6e0b08459456cd698e8a92309 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 8 Aug 2024 23:00:03 +0200 Subject: [PATCH 045/113] Remove unnecessary LanceDBLoadJob attributes Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 86 +++++++++++-------- dlt/destinations/sql_jobs.py | 1 - 2 files changed, 52 insertions(+), 35 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 7c69f5db8a..c16f67e702 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -19,6 +19,7 @@ import pyarrow as pa import pyarrow.compute as pc import pyarrow.parquet as pq +import yaml from lancedb import DBConnection from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore @@ -692,16 +693,9 @@ def complete_load(self, load_id: str) -> None: def create_load_job( self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: - parent_table = table.get("parent") - return LanceDBLoadJob( file_path=file_path, - type_mapper=self.type_mapper, - model_func=self.model_func, fq_table_name=self.make_qualified_table_name(table["name"]), - fq_parent_table_name=( - self.make_qualified_table_name(parent_table) if parent_table else None - ), ) def create_table_chain_completed_followup_jobs( @@ -714,25 +708,25 @@ def create_table_chain_completed_followup_jobs( table_chain, completed_table_chain_jobs ) - for table in table_chain: - if table in self.schema.dlt_tables(): - continue - - # Only tables with merge disposition are dispatched for orphan removal jobs. 
- if table.get("write_disposition") == "merge": - parent_table = table.get("parent") - jobs.append( - LanceDBRemoveOrphansJob( - db_client=self.db_client, - table_schema=self.prepare_load_table(table["name"]), - fq_table_name=self.make_qualified_table_name(table["name"]), - fq_parent_table_name=( - self.make_qualified_table_name(parent_table) if parent_table else None - ), - client_config=self.config, - ) - ) - + # for table in table_chain: + # if table in self.schema.dlt_tables(): + # continue + # + # # Only tables with merge disposition are dispatched for orphan removal jobs. + # if table.get("write_disposition") == "merge": + # parent_table = table.get("parent") + # jobs.append( + # LanceDBRemoveOrphansJob( + # db_client=self.db_client, + # table_schema=self.prepare_load_table(table["name"]), + # fq_table_name=self.make_qualified_table_name(table["name"]), + # fq_parent_table_name=( + # self.make_qualified_table_name(parent_table) if parent_table else None + # ), + # client_config=self.config, + # ) + # ) + # return jobs def table_exists(self, table_name: str) -> bool: @@ -745,21 +739,14 @@ class LanceDBLoadJob(RunnableLoadJob): def __init__( self, file_path: str, - type_mapper: LanceDBTypeMapper, - model_func: TextEmbeddingFunction, fq_table_name: str, - fq_parent_table_name: Optional[str], ) -> None: super().__init__(file_path) - self._type_mapper: TypeMapper = type_mapper self._fq_table_name: str = fq_table_name - self._model_func = model_func self._job_client: "LanceDBClient" = None def run(self) -> None: self._db_client: DBConnection = self._job_client.db_client - self._embedding_model_func: TextEmbeddingFunction = self._model_func - self._embedding_model_dimensions: int = self._job_client.config.embedding_model_dimensions self._id_field_name: str = self._job_client.config.id_field_name unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema( @@ -822,7 +809,7 @@ def __init__( self.execute() - def execute(self) -> None: + def run(self) -> None: orphaned_ids: Set[str] if self.write_disposition != "merge": @@ -903,3 +890,34 @@ def execute(self) -> None: raise DestinationTerminalException( "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." ) from e + + @classmethod + def from_table_chain( + cls, + table_chain: Sequence[TTableSchema], + ) -> FollowupJobImpl: + """Generates a list of orphan removal tasks that the client will execute when the job is executed in the loader. + + The `table_chain` contains a listo of table schemas with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). + """ + top_table = table_chain[0] + file_info = ParsedLoadJobFileName( + top_table["name"], ParsedLoadJobFileName.new_file_id(), 0, "parquet" + ) + try: + job = cls(file_info.file_name()) + # Write parquet file contents?? + except Exception as e: + raise LanceDBJobCreationException(e, table_chain) from e + return job + + +class LanceDBJobCreationException(DestinationTransientException): + def __init__(self, original_exception: Exception, table_chain: Sequence[TTableSchema]) -> None: + tables_chain = yaml.dump( + table_chain, allow_unicode=True, default_flow_style=False, sort_keys=False + ) + super().__init__( + f"Could not create SQLFollowupJob with exception {str(original_exception)}. 
Table" + f" chain: {tables_chain}" + ) diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index cddae52bb7..643264968a 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -75,7 +75,6 @@ def from_table_chain( job = cls(file_info.file_name()) job._save_text_file("\n".join(sql)) except Exception as e: - # raise exception with some context raise SqlJobCreationException(e, table_chain) from e return job From 7e33011f52f35759b0a1ec1fe50906daf4762efc Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Fri, 9 Aug 2024 00:31:29 +0200 Subject: [PATCH 046/113] Change instance attributes to `run` method as variables Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 126 ++++++------------ 1 file changed, 40 insertions(+), 86 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index c16f67e702..9ec9288d6b 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -78,7 +78,7 @@ set_non_standard_providers_environment_variables, generate_arrow_uuid_column, ) -from dlt.destinations.job_impl import FollowupJobImpl +from dlt.destinations.job_impl import ReferenceFollowupJob from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -693,41 +693,17 @@ def complete_load(self, load_id: str) -> None: def create_load_job( self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: - return LanceDBLoadJob( - file_path=file_path, - fq_table_name=self.make_qualified_table_name(table["name"]), - ) + if ReferenceFollowupJob.is_reference_job(file_path): + return LanceDBRemoveOrphansJob(file_path, table) + else: + return LanceDBLoadJob(file_path, table) def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJob]: - assert completed_table_chain_jobs is not None - jobs = super().create_table_chain_completed_followup_jobs( - table_chain, completed_table_chain_jobs - ) - - # for table in table_chain: - # if table in self.schema.dlt_tables(): - # continue - # - # # Only tables with merge disposition are dispatched for orphan removal jobs. 
- # if table.get("write_disposition") == "merge": - # parent_table = table.get("parent") - # jobs.append( - # LanceDBRemoveOrphansJob( - # db_client=self.db_client, - # table_schema=self.prepare_load_table(table["name"]), - # fq_table_name=self.make_qualified_table_name(table["name"]), - # fq_parent_table_name=( - # self.make_qualified_table_name(parent_table) if parent_table else None - # ), - # client_config=self.config, - # ) - # ) - # - return jobs + return [LanceDBRemoveOrphansJob.from_table_chain(table_chain)] def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() @@ -739,16 +715,16 @@ class LanceDBLoadJob(RunnableLoadJob): def __init__( self, file_path: str, - fq_table_name: str, + table_schema: TTableSchema, ) -> None: super().__init__(file_path) - self._fq_table_name: str = fq_table_name self._job_client: "LanceDBClient" = None + self._table_schema: TTableSchema = table_schema def run(self) -> None: - self._db_client: DBConnection = self._job_client.db_client - self._id_field_name: str = self._job_client.config.id_field_name - + db_client: DBConnection = self._job_client.db_client + fq_table_name: str = self._job_client.make_qualified_table_name(self._table_schema["name"]) + id_field_name: str = self._job_client.config.id_field_name unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema( self._load_table ) @@ -763,67 +739,44 @@ def run(self) -> None: arrow_table = generate_arrow_uuid_column( arrow_table, unique_identifiers=unique_identifiers, - table_name=self._fq_table_name, - id_field_name=self._id_field_name, + table_name=fq_table_name, + id_field_name=id_field_name, ) write_to_db( arrow_table, - db_client=self._db_client, - table_name=self._fq_table_name, + db_client=db_client, + table_name=fq_table_name, write_disposition=write_disposition, - id_field_name=self._id_field_name, + id_field_name=id_field_name, ) -class LanceDBRemoveOrphansJob(FollowupJobImpl): +class LanceDBRemoveOrphansJob(RunnableLoadJob): + orphaned_ids: Set[str] + def __init__( self, - db_client: DBConnection, + file_path: str, table_schema: TTableSchema, - client_config: LanceDBClientConfiguration, - fq_table_name: str, - fq_parent_table_name: Optional[str], ) -> None: - self.db_client = db_client - self.table_schema: TTableSchema = table_schema - self.fq_table_name: str = fq_table_name - self.fq_parent_table_name: Optional[str] = fq_parent_table_name - self.write_disposition: TWriteDisposition = cast( - TWriteDisposition, self.table_schema.get("write_disposition") - ) - self.id_field_name: str = client_config.id_field_name - - job_id = ParsedLoadJobFileName( - table_schema["name"], - ParsedLoadJobFileName.new_file_id(), - 0, - "parquet", - ).file_name() - - super().__init__( - file_name=job_id, - ) - - self._save_text_file("") - - self.execute() + super().__init__(file_path) + self._job_client: "LanceDBClient" = None + self._table_schema: TTableSchema = table_schema def run(self) -> None: - orphaned_ids: Set[str] - - if self.write_disposition != "merge": - raise DestinationTerminalException( - f"Unsupported write disposition {self.write_disposition} for LanceDB Destination" - " Orphan Removal Job - failed AND WILL **NOT** BE RETRIED." 
- ) + db_client: DBConnection = self._job_client.db_client + fq_table_name: str = self._job_client.make_qualified_table_name(self._table_schema["name"]) + fq_parent_table_name: str = self._job_client.make_qualified_table_name( + self._table_schema["parent"] + ) + id_field_name: str = self._job_client.config.id_field_name - # Orphans are removed irrespective of which merge strategy is picked. try: - child_table = self.db_client.open_table(self.fq_table_name) + child_table = db_client.open_table(fq_table_name) child_table.checkout_latest() - if self.fq_parent_table_name: - parent_table = self.db_client.open_table(self.fq_parent_table_name) + if fq_parent_table_name: + parent_table = db_client.open_table(fq_parent_table_name) parent_table.checkout_latest() except FileNotFoundError as e: raise DestinationTransientException( @@ -831,10 +784,9 @@ def run(self) -> None: ) from e try: - if self.fq_parent_table_name: + if fq_parent_table_name: # Chunks and embeddings in child table. - # By referencing the underlying lance dataset, we benefit from projection push-down to the storage layer (LanceDB). parent_ids = set( pc.unique( parent_table.to_lance().to_table(columns=["_dlt_id"])["_dlt_id"] @@ -857,9 +809,11 @@ def run(self) -> None: else: # Chunks and embeddings in the root table. - document_id_field = get_columns_names_with_prop(self.table_schema, DOCUMENT_ID_HINT) + document_id_field = get_columns_names_with_prop( + self._table_schema, DOCUMENT_ID_HINT + ) if document_id_field and get_columns_names_with_prop( - self.table_schema, "primary_key" + self._table_schema, "primary_key" ): raise SystemConfigurationException( "You CANNOT specify a primary key AND a document ID hint for the same" @@ -868,7 +822,7 @@ def run(self) -> None: # If document ID is defined, we use this as the sole grouping key to identify stale chunks, # else fallback to the compound `id_field_name`. - grouping_key = document_id_field or self.id_field_name + grouping_key = document_id_field or id_field_name grouping_key = grouping_key if isinstance(grouping_key, list) else [grouping_key] child_table_arrow: pa.Table = child_table.to_lance().to_table( columns=[*grouping_key, "_dlt_load_id", "_dlt_id"] @@ -895,10 +849,10 @@ def run(self) -> None: def from_table_chain( cls, table_chain: Sequence[TTableSchema], - ) -> FollowupJobImpl: + ) -> "LanceDBRemoveOrphansJob": """Generates a list of orphan removal tasks that the client will execute when the job is executed in the loader. - The `table_chain` contains a listo of table schemas with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). + The `table_chain` contains a list of table schemas with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). 
""" top_table = table_chain[0] file_info = ParsedLoadJobFileName( From ee7dd0260c65ae44b664744e9174e82c7b53f8e6 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Fri, 9 Aug 2024 23:36:19 +0200 Subject: [PATCH 047/113] Schedule follow up refernce job Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 50 +++++-------------- 1 file changed, 13 insertions(+), 37 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 9ec9288d6b..64a8386eea 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -19,7 +19,6 @@ import pyarrow as pa import pyarrow.compute as pc import pyarrow.parquet as pq -import yaml from lancedb import DBConnection from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore @@ -41,6 +40,8 @@ StateInfo, FollowupJob, LoadJob, + HasFollowupJobs, + TLoadJobState, ) from dlt.common.exceptions import SystemConfigurationException from dlt.common.pendulum import timedelta @@ -53,7 +54,7 @@ TColumnSchema, ) from dlt.common.schema.utils import get_columns_names_with_prop -from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName +from dlt.common.storages import FileStorage, LoadJobInfo from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import ( LanceDBClientConfiguration, @@ -696,14 +697,17 @@ def create_load_job( if ReferenceFollowupJob.is_reference_job(file_path): return LanceDBRemoveOrphansJob(file_path, table) else: - return LanceDBLoadJob(file_path, table) + return LanceDBLoadJobWithFollowup(file_path, table) def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJob]: - return [LanceDBRemoveOrphansJob.from_table_chain(table_chain)] + table_job_paths = [job.file_path for job in completed_table_chain_jobs] + file_name = FileStorage.get_file_name_from_file_path(table_job_paths[0]) + job = ReferenceFollowupJob(file_name, table_job_paths) + return [job] def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() @@ -752,6 +756,11 @@ def run(self) -> None: ) +class LanceDBLoadJobWithFollowup(HasFollowupJobs, LanceDBLoadJob): + def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJob]: + return super().create_followup_jobs(final_state) + + class LanceDBRemoveOrphansJob(RunnableLoadJob): orphaned_ids: Set[str] @@ -786,7 +795,6 @@ def run(self) -> None: try: if fq_parent_table_name: # Chunks and embeddings in child table. - parent_ids = set( pc.unique( parent_table.to_lance().to_table(columns=["_dlt_id"])["_dlt_id"] @@ -808,7 +816,6 @@ def run(self) -> None: else: # Chunks and embeddings in the root table. - document_id_field = get_columns_names_with_prop( self._table_schema, DOCUMENT_ID_HINT ) @@ -844,34 +851,3 @@ def run(self) -> None: raise DestinationTerminalException( "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." ) from e - - @classmethod - def from_table_chain( - cls, - table_chain: Sequence[TTableSchema], - ) -> "LanceDBRemoveOrphansJob": - """Generates a list of orphan removal tasks that the client will execute when the job is executed in the loader. 
- - The `table_chain` contains a list of table schemas with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). - """ - top_table = table_chain[0] - file_info = ParsedLoadJobFileName( - top_table["name"], ParsedLoadJobFileName.new_file_id(), 0, "parquet" - ) - try: - job = cls(file_info.file_name()) - # Write parquet file contents?? - except Exception as e: - raise LanceDBJobCreationException(e, table_chain) from e - return job - - -class LanceDBJobCreationException(DestinationTransientException): - def __init__(self, original_exception: Exception, table_chain: Sequence[TTableSchema]) -> None: - tables_chain = yaml.dump( - table_chain, allow_unicode=True, default_flow_style=False, sort_keys=False - ) - super().__init__( - f"Could not create SQLFollowupJob with exception {str(original_exception)}. Table" - f" chain: {tables_chain}" - ) From df498abf78df8d137c28142d02ee63e912f77525 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sat, 10 Aug 2024 11:08:41 +0200 Subject: [PATCH 048/113] Add follow up lancedb remove orphan job skeleron Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 35 +++++++++++++------ 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 64a8386eea..23cb6cd542 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -41,7 +41,6 @@ FollowupJob, LoadJob, HasFollowupJobs, - TLoadJobState, ) from dlt.common.exceptions import SystemConfigurationException from dlt.common.pendulum import timedelta @@ -54,7 +53,7 @@ TColumnSchema, ) from dlt.common.schema.utils import get_columns_names_with_prop -from dlt.common.storages import FileStorage, LoadJobInfo +from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import ( LanceDBClientConfiguration, @@ -79,7 +78,7 @@ set_non_standard_providers_environment_variables, generate_arrow_uuid_column, ) -from dlt.destinations.job_impl import ReferenceFollowupJob +from dlt.destinations.job_impl import FollowupJobImpl from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -694,7 +693,7 @@ def complete_load(self, load_id: str) -> None: def create_load_job( self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: - if ReferenceFollowupJob.is_reference_job(file_path): + if file_path.endswith(".remove_orphans"): return LanceDBRemoveOrphansJob(file_path, table) else: return LanceDBLoadJobWithFollowup(file_path, table) @@ -704,10 +703,22 @@ def create_table_chain_completed_followup_jobs( table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJob]: - table_job_paths = [job.file_path for job in completed_table_chain_jobs] - file_name = FileStorage.get_file_name_from_file_path(table_job_paths[0]) - job = ReferenceFollowupJob(file_name, table_job_paths) - return [job] + assert completed_table_chain_jobs is not None + jobs = super().create_table_chain_completed_followup_jobs( + table_chain, completed_table_chain_jobs + ) + for table in table_chain: + if ( + table.get("write_disposition") == "merge" + and table["name"] not in self.schema.dlt_table_names() + ): + file_name = repr( + ParsedLoadJobFileName( + table["name"], ParsedLoadJobFileName.new_file_id(), 0, "remove_orphans" + ) + ) 
+ jobs.append(FollowupJobImpl(file_name)) + return jobs def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() @@ -757,8 +768,7 @@ def run(self) -> None: class LanceDBLoadJobWithFollowup(HasFollowupJobs, LanceDBLoadJob): - def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJob]: - return super().create_followup_jobs(final_state) + pass class LanceDBRemoveOrphansJob(RunnableLoadJob): @@ -851,3 +861,8 @@ def run(self) -> None: raise DestinationTerminalException( "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." ) from e + + +class LanceDBRemoveOrphansFollowupJob(FollowupJobImpl): + def __init__(self, file_name: str) -> None: + super().__init__(file_name) From c08f1ba948a88eba759162da48e527e1f9c36cb4 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sat, 10 Aug 2024 11:33:07 +0200 Subject: [PATCH 049/113] Write empty follow up file Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 23cb6cd542..180750d8ad 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -1,3 +1,4 @@ +import os from types import TracebackType from typing import ( List, @@ -693,7 +694,7 @@ def complete_load(self, load_id: str) -> None: def create_load_job( self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: - if file_path.endswith(".remove_orphans"): + if LanceDBRemoveOrphansFollowupJob.is_remove_orphan_job(file_path): return LanceDBRemoveOrphansJob(file_path, table) else: return LanceDBLoadJobWithFollowup(file_path, table) @@ -703,22 +704,7 @@ def create_table_chain_completed_followup_jobs( table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJob]: - assert completed_table_chain_jobs is not None - jobs = super().create_table_chain_completed_followup_jobs( - table_chain, completed_table_chain_jobs - ) - for table in table_chain: - if ( - table.get("write_disposition") == "merge" - and table["name"] not in self.schema.dlt_table_names() - ): - file_name = repr( - ParsedLoadJobFileName( - table["name"], ParsedLoadJobFileName.new_file_id(), 0, "remove_orphans" - ) - ) - jobs.append(FollowupJobImpl(file_name)) - return jobs + return LanceDBRemoveOrphansFollowupJob.from_table_chain(table_chain) def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() @@ -866,3 +852,22 @@ def run(self) -> None: class LanceDBRemoveOrphansFollowupJob(FollowupJobImpl): def __init__(self, file_name: str) -> None: super().__init__(file_name) + self._save_text_file("") + + @classmethod + def from_table_chain( + cls, + table_chain: Sequence[TTableSchema], + ) -> List[FollowupJobImpl]: + jobs = [] + for table in table_chain: + if table.get("write_disposition") == "merge": + file_info = ParsedLoadJobFileName( + table["name"], ParsedLoadJobFileName.new_file_id(), 0, "remove_orphans" + ) + jobs.append(cls(file_info.file_name())) + return jobs + + @staticmethod + def is_remove_orphan_job(file_path: str) -> bool: + return os.path.basename(file_path) == ".remove_orphans" From f9f94e3d5024dc7ca194758140ad3b62bd4d9597 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sat, 10 Aug 2024 15:22:41 +0200 Subject: [PATCH 050/113] Write parquet 
Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 180750d8ad..57aa956286 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -852,7 +852,10 @@ def run(self) -> None: class LanceDBRemoveOrphansFollowupJob(FollowupJobImpl): def __init__(self, file_name: str) -> None: super().__init__(file_name) - self._save_text_file("") + self._write_empty_parquet_file() + + def _write_empty_parquet_file(self): + pq.write_table(pa.table({}), self._file_path) @classmethod def from_table_chain( @@ -862,12 +865,14 @@ def from_table_chain( jobs = [] for table in table_chain: if table.get("write_disposition") == "merge": + # TODO: insert Identify orphan IDs into load job file here. Removal should then happen through `write_to_db` in orphan removal job. file_info = ParsedLoadJobFileName( - table["name"], ParsedLoadJobFileName.new_file_id(), 0, "remove_orphans" + table["name"], ParsedLoadJobFileName.new_file_id(), 0, "parquet" ) jobs.append(cls(file_info.file_name())) return jobs @staticmethod def is_remove_orphan_job(file_path: str) -> bool: - return os.path.basename(file_path) == ".remove_orphans" + return os.path.basename(file_path) == ".parquet" + From b374b0b73800682797db48bdb3fccf2ccd45799c Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sat, 10 Aug 2024 16:03:39 +0200 Subject: [PATCH 051/113] Add support for reference file format in LanceDB destination Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/factory.py | 2 +- dlt/destinations/impl/lancedb/lancedb_client.py | 11 +++-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/dlt/destinations/impl/lancedb/factory.py b/dlt/destinations/impl/lancedb/factory.py index 9acb82344b..d9b92e02b9 100644 --- a/dlt/destinations/impl/lancedb/factory.py +++ b/dlt/destinations/impl/lancedb/factory.py @@ -17,7 +17,7 @@ class lancedb(Destination[LanceDBClientConfiguration, "LanceDBClient"]): def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = "parquet" - caps.supported_loader_file_formats = ["parquet"] + caps.supported_loader_file_formats = ["parquet", "reference"] caps.max_identifier_length = 200 caps.max_column_identifier_length = 1024 diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 57aa956286..94d62fc8f1 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -852,10 +852,7 @@ def run(self) -> None: class LanceDBRemoveOrphansFollowupJob(FollowupJobImpl): def __init__(self, file_name: str) -> None: super().__init__(file_name) - self._write_empty_parquet_file() - - def _write_empty_parquet_file(self): - pq.write_table(pa.table({}), self._file_path) + self._save_text_file("") @classmethod def from_table_chain( @@ -865,14 +862,12 @@ def from_table_chain( jobs = [] for table in table_chain: if table.get("write_disposition") == "merge": - # TODO: insert Identify orphan IDs into load job file here. Removal should then happen through `write_to_db` in orphan removal job. 
file_info = ParsedLoadJobFileName( - table["name"], ParsedLoadJobFileName.new_file_id(), 0, "parquet" + table["name"], ParsedLoadJobFileName.new_file_id(), 0, "reference" ) jobs.append(cls(file_info.file_name())) return jobs @staticmethod def is_remove_orphan_job(file_path: str) -> bool: - return os.path.basename(file_path) == ".parquet" - + return os.path.splitext(file_path)[1][1:] == "reference" From 2ed3301d094dc49f3786e95876e135bf7aacaa96 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sat, 10 Aug 2024 19:30:56 +0200 Subject: [PATCH 052/113] Handle parent table name resolution if it doesn't exist in Lance db remove orphan job Signed-off-by: Marcel Coetzee --- dlt/common/schema/exceptions.py | 9 ++++++--- dlt/common/schema/utils.py | 4 +++- dlt/destinations/impl/lancedb/lancedb_client.py | 17 +++++++++-------- dlt/normalize/schema.py | 5 ++++- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/dlt/common/schema/exceptions.py b/dlt/common/schema/exceptions.py index 1055163942..2e75b4b3a1 100644 --- a/dlt/common/schema/exceptions.py +++ b/dlt/common/schema/exceptions.py @@ -246,12 +246,15 @@ def __init__(self, schema_name: str, table_name: str, column: TColumnSchemaBase) elif column.get("primary_key"): key_type = "primary key" - msg = f"The column {column['name']} in table {table_name} did not receive any data during this load. " + msg = ( + f"The column {column['name']} in table {table_name} did not receive any data during" + " this load. " + ) if key_type or not nullable: msg += f"It is marked as non-nullable{' '+key_type} and it must have values. " msg += ( - "This can happen if you specify the column manually, for example using the 'merge_key', 'primary_key' or 'columns' argument " - "but it does not exist in the data." + "This can happen if you specify the column manually, for example using the 'merge_key'," + " 'primary_key' or 'columns' argument but it does not exist in the data." ) super().__init__(schema_name, msg) diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index d879c21b3c..8b87a7e5fe 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -357,7 +357,9 @@ def is_nullable_column(col: TColumnSchemaBase) -> bool: return col.get("nullable", True) -def find_incomplete_columns(tables: List[TTableSchema]) -> Iterable[Tuple[str, TColumnSchemaBase, bool]]: +def find_incomplete_columns( + tables: List[TTableSchema], +) -> Iterable[Tuple[str, TColumnSchemaBase, bool]]: """Yields (table_name, column, nullable) for all incomplete columns in `tables`""" for table in tables: for col in table["columns"].values(): diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 94d62fc8f1..957f1595da 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -675,9 +675,7 @@ def complete_load(self, load_id: str) -> None: self.schema.naming.normalize_identifier("schema_name"): self.schema.name, self.schema.naming.normalize_identifier("status"): 0, self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), - self.schema.naming.normalize_identifier( - "schema_version_hash" - ): None, # Payload schema must match the target schema. 
+ self.schema.naming.normalize_identifier("schema_version_hash"): None, } ] fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) @@ -704,7 +702,7 @@ def create_table_chain_completed_followup_jobs( table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJob]: - return LanceDBRemoveOrphansFollowupJob.from_table_chain(table_chain) + return LanceDBRemoveOrphansFollowupJob.from_table_chain(table_chain) # type: ignore def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() @@ -772,9 +770,12 @@ def __init__( def run(self) -> None: db_client: DBConnection = self._job_client.db_client fq_table_name: str = self._job_client.make_qualified_table_name(self._table_schema["name"]) - fq_parent_table_name: str = self._job_client.make_qualified_table_name( - self._table_schema["parent"] - ) + try: + fq_parent_table_name: str = self._job_client.make_qualified_table_name( + self._table_schema["parent"] + ) + except KeyError: + fq_parent_table_name = None # The table is a root table. id_field_name: str = self._job_client.config.id_field_name try: @@ -858,7 +859,7 @@ def __init__(self, file_name: str) -> None: def from_table_chain( cls, table_chain: Sequence[TTableSchema], - ) -> List[FollowupJobImpl]: + ) -> List["LanceDBRemoveOrphansFollowupJob"]: jobs = [] for table in table_chain: if table.get("write_disposition") == "merge": diff --git a/dlt/normalize/schema.py b/dlt/normalize/schema.py index 4967fab18f..c01d184c92 100644 --- a/dlt/normalize/schema.py +++ b/dlt/normalize/schema.py @@ -3,13 +3,16 @@ from dlt.common.schema.exceptions import UnboundColumnException from dlt.common import logger + def verify_normalized_schema(schema: Schema) -> None: """Verify the schema is valid for next stage after normalization. 1. Log warning if any incomplete nullable columns are in any data tables 2. Raise `UnboundColumnException` on incomplete non-nullable columns (e.g. 
missing merge/primary key) """ - for table_name, column, nullable in find_incomplete_columns(schema.data_tables(seen_data_only=True)): + for table_name, column, nullable in find_incomplete_columns( + schema.data_tables(seen_data_only=True) + ): exc = UnboundColumnException(schema.name, table_name, column) if nullable: logger.warning(str(exc)) From 0694859567f083a0a90087195481a033dc3e9231 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 15 Aug 2024 22:58:38 +0200 Subject: [PATCH 053/113] Refactor specialised orphan follow up job back to reference job Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 53 +++++++------------ 1 file changed, 18 insertions(+), 35 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 957f1595da..d53dda9f36 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -1,4 +1,3 @@ -import os from types import TracebackType from typing import ( List, @@ -54,7 +53,7 @@ TColumnSchema, ) from dlt.common.schema.utils import get_columns_names_with_prop -from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName +from dlt.common.storages import FileStorage, LoadJobInfo from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import ( LanceDBClientConfiguration, @@ -79,7 +78,7 @@ set_non_standard_providers_environment_variables, generate_arrow_uuid_column, ) -from dlt.destinations.job_impl import FollowupJobImpl +from dlt.destinations.job_impl import ReferenceFollowupJob from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -692,23 +691,35 @@ def complete_load(self, load_id: str) -> None: def create_load_job( self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: - if LanceDBRemoveOrphansFollowupJob.is_remove_orphan_job(file_path): + if ReferenceFollowupJob.is_reference_job(file_path): return LanceDBRemoveOrphansJob(file_path, table) else: - return LanceDBLoadJobWithFollowup(file_path, table) + return LanceDBLoadJob(file_path, table) def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJob]: - return LanceDBRemoveOrphansFollowupJob.from_table_chain(table_chain) # type: ignore + jobs = super().create_table_chain_completed_followup_jobs( + table_chain, completed_table_chain_jobs + ) + if table_chain[0].get("write_disposition") == "merge": + for table in table_chain: + table_job_paths = [ + job.file_path + for job in completed_table_chain_jobs + if job.job_file_info.table_name == table["name"] + ] + file_name = FileStorage.get_file_name_from_file_path(table_job_paths[0]) + jobs.append(ReferenceFollowupJob(file_name, table_job_paths)) + return jobs def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() -class LanceDBLoadJob(RunnableLoadJob): +class LanceDBLoadJob(RunnableLoadJob, HasFollowupJobs): arrow_schema: TArrowSchema def __init__( @@ -751,10 +762,6 @@ def run(self) -> None: ) -class LanceDBLoadJobWithFollowup(HasFollowupJobs, LanceDBLoadJob): - pass - - class LanceDBRemoveOrphansJob(RunnableLoadJob): orphaned_ids: Set[str] @@ -848,27 +855,3 @@ def run(self) -> None: raise DestinationTerminalException( "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." 
) from e - - -class LanceDBRemoveOrphansFollowupJob(FollowupJobImpl): - def __init__(self, file_name: str) -> None: - super().__init__(file_name) - self._save_text_file("") - - @classmethod - def from_table_chain( - cls, - table_chain: Sequence[TTableSchema], - ) -> List["LanceDBRemoveOrphansFollowupJob"]: - jobs = [] - for table in table_chain: - if table.get("write_disposition") == "merge": - file_info = ParsedLoadJobFileName( - table["name"], ParsedLoadJobFileName.new_file_id(), 0, "reference" - ) - jobs.append(cls(file_info.file_name())) - return jobs - - @staticmethod - def is_remove_orphan_job(file_path: str) -> bool: - return os.path.splitext(file_path)[1][1:] == "reference" From 537a2be7f02e528f587ded57f82736578e45ec22 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sat, 17 Aug 2024 19:30:16 +0200 Subject: [PATCH 054/113] Refactor orphan removal for chunked documents Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 174 +++++++++--------- 1 file changed, 92 insertions(+), 82 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index d53dda9f36..dd59fa981a 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -53,7 +53,7 @@ TColumnSchema, ) from dlt.common.schema.utils import get_columns_names_with_prop -from dlt.common.storages import FileStorage, LoadJobInfo +from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import ( LanceDBClientConfiguration, @@ -692,7 +692,7 @@ def create_load_job( self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: if ReferenceFollowupJob.is_reference_job(file_path): - return LanceDBRemoveOrphansJob(file_path, table) + return LanceDBRemoveOrphansJob(file_path) else: return LanceDBLoadJob(file_path, table) @@ -705,14 +705,17 @@ def create_table_chain_completed_followup_jobs( table_chain, completed_table_chain_jobs ) if table_chain[0].get("write_disposition") == "merge": - for table in table_chain: - table_job_paths = [ - job.file_path - for job in completed_table_chain_jobs - if job.job_file_info.table_name == table["name"] - ] - file_name = FileStorage.get_file_name_from_file_path(table_job_paths[0]) - jobs.append(ReferenceFollowupJob(file_name, table_job_paths)) + # TODO: Use staging to write deletion records. For now we use only one job. + all_job_paths_ordered = [ + job.file_path + for table in table_chain + for job in completed_table_chain_jobs + if job.job_file_info.table_name == table.get("name") + ] + root_table_file_name = FileStorage.get_file_name_from_file_path( + all_job_paths_ordered[0] + ) + jobs.append(ReferenceFollowupJob(root_table_file_name, all_job_paths_ordered)) return jobs def table_exists(self, table_name: str) -> bool: @@ -762,96 +765,103 @@ def run(self) -> None: ) +# TODO: Implement staging for this step with insert deletes. 
class LanceDBRemoveOrphansJob(RunnableLoadJob): orphaned_ids: Set[str] def __init__( self, file_path: str, - table_schema: TTableSchema, ) -> None: super().__init__(file_path) self._job_client: "LanceDBClient" = None - self._table_schema: TTableSchema = table_schema + self.references = ReferenceFollowupJob.resolve_references(file_path) def run(self) -> None: db_client: DBConnection = self._job_client.db_client - fq_table_name: str = self._job_client.make_qualified_table_name(self._table_schema["name"]) - try: - fq_parent_table_name: str = self._job_client.make_qualified_table_name( - self._table_schema["parent"] - ) - except KeyError: - fq_parent_table_name = None # The table is a root table. id_field_name: str = self._job_client.config.id_field_name - try: - child_table = db_client.open_table(fq_table_name) - child_table.checkout_latest() - if fq_parent_table_name: - parent_table = db_client.open_table(fq_parent_table_name) - parent_table.checkout_latest() - except FileNotFoundError as e: - raise DestinationTransientException( - "Couldn't open lancedb database. Orphan removal WILL BE RETRIED" - ) from e - - try: - if fq_parent_table_name: - # Chunks and embeddings in child table. - parent_ids = set( - pc.unique( - parent_table.to_lance().to_table(columns=["_dlt_id"])["_dlt_id"] - ).to_pylist() - ) - child_ids = set( - pc.unique( - child_table.to_lance().to_table(columns=["_dlt_parent_id"])[ - "_dlt_parent_id" - ] - ).to_pylist() + # We don't all insert jobs for each table using this method. + table_lineage: List[TTableSchema] = [] + for file_path_ in self.references: + table = self._schema.get_table(ParsedLoadJobFileName.parse(file_path_).table_name) + if table["name"] not in [table_["name"] for table_ in table_lineage]: + table_lineage.append(table) + + for table in table_lineage: + fq_table_name: str = self._job_client.make_qualified_table_name(table["name"]) + try: + fq_parent_table_name: str = self._job_client.make_qualified_table_name( + table["parent"] ) + except KeyError: + fq_parent_table_name = None # The table is a root table. + + try: + child_table = db_client.open_table(fq_table_name) + child_table.checkout_latest() + if fq_parent_table_name: + parent_table = db_client.open_table(fq_parent_table_name) + parent_table.checkout_latest() + except FileNotFoundError as e: + raise DestinationTransientException( + "Couldn't open lancedb database. Orphan removal WILL BE RETRIED" + ) from e + + try: + if fq_parent_table_name: + # Chunks and embeddings in child table. + parent_ids = set( + pc.unique( + parent_table.to_lance().to_table(columns=["_dlt_id"])["_dlt_id"] + ).to_pylist() + ) + child_ids = set( + pc.unique( + child_table.to_lance().to_table(columns=["_dlt_parent_id"])[ + "_dlt_parent_id" + ] + ).to_pylist() + ) - if orphaned_ids := child_ids - parent_ids: - if len(orphaned_ids) > 1: - child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids) == 1: - child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") + if orphaned_ids := child_ids - parent_ids: + if len(orphaned_ids) > 1: + child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") + elif len(orphaned_ids) == 1: + child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") - else: - # Chunks and embeddings in the root table. 
- document_id_field = get_columns_names_with_prop( - self._table_schema, DOCUMENT_ID_HINT - ) - if document_id_field and get_columns_names_with_prop( - self._table_schema, "primary_key" - ): - raise SystemConfigurationException( - "You CANNOT specify a primary key AND a document ID hint for the same" - " resource when using merge disposition." - ) + else: + # Chunks and embeddings in the root table. + document_id_field = get_columns_names_with_prop(table, DOCUMENT_ID_HINT) + if document_id_field and get_columns_names_with_prop(table, "primary_key"): + raise SystemConfigurationException( + "You CANNOT specify a primary key AND a document ID hint for the same" + " resource when using merge disposition." + ) - # If document ID is defined, we use this as the sole grouping key to identify stale chunks, - # else fallback to the compound `id_field_name`. - grouping_key = document_id_field or id_field_name - grouping_key = grouping_key if isinstance(grouping_key, list) else [grouping_key] - child_table_arrow: pa.Table = child_table.to_lance().to_table( - columns=[*grouping_key, "_dlt_load_id", "_dlt_id"] - ) + # If document ID is defined, we use this as the sole grouping key to identify stale chunks, + # else fallback to the compound `id_field_name`. + grouping_key = document_id_field or id_field_name + grouping_key = ( + grouping_key if isinstance(grouping_key, list) else [grouping_key] + ) + child_table_arrow: pa.Table = child_table.to_lance().to_table( + columns=[*grouping_key, "_dlt_load_id", "_dlt_id"] + ) - grouped = child_table_arrow.group_by(grouping_key).aggregate( - [("_dlt_load_id", "max")] - ) - joined = child_table_arrow.join(grouped, keys=grouping_key) - orphaned_mask = pc.not_equal(joined["_dlt_load_id"], joined["_dlt_load_id_max"]) - orphaned_ids = joined.filter(orphaned_mask).column("_dlt_id").to_pylist() + grouped = child_table_arrow.group_by(grouping_key).aggregate( + [("_dlt_load_id", "max")] + ) + joined = child_table_arrow.join(grouped, keys=grouping_key) + orphaned_mask = pc.not_equal(joined["_dlt_load_id"], joined["_dlt_load_id_max"]) + orphaned_ids = joined.filter(orphaned_mask).column("_dlt_id").to_pylist() - if len(orphaned_ids) > 1: - child_table.delete(f"_dlt_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids) == 1: - child_table.delete(f"_dlt_id = '{orphaned_ids.pop()}'") + if len(orphaned_ids) > 1: + child_table.delete(f"_dlt_id IN {tuple(orphaned_ids)}") + elif len(orphaned_ids) == 1: + child_table.delete(f"_dlt_id = '{orphaned_ids.pop()}'") - except ArrowInvalid as e: - raise DestinationTerminalException( - "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." - ) from e + except ArrowInvalid as e: + raise DestinationTerminalException( + "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." 
+ ) from e From 3d2530656325b9cbfe8fa2c3fc25625352a355ab Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 18 Aug 2024 21:13:44 +0200 Subject: [PATCH 055/113] Fix dlt system table check for name instead of object Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 2 +- dlt/destinations/impl/lancedb/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index dd59fa981a..ecdd22ca56 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -748,7 +748,7 @@ def run(self) -> None: with FileStorage.open_zipsafe_ro(self._file_path, mode="rb") as f: arrow_table: pa.Table = pq.read_table(f) - if self._load_table not in self._schema.dlt_tables(): + if self._load_table["name"] not in self._schema.dlt_table_names(): arrow_table = generate_arrow_uuid_column( arrow_table, unique_identifiers=unique_identifiers, diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 466111dc98..37303686df 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -27,7 +27,7 @@ def generate_arrow_uuid_column( Args: table (pa.Table): PyArrow table to generate UUIDs for. - unique_identifiers (List[str]): A list of unique identifier column names. + unique_identifiers (Sequence[str]): A list of unique identifier column names. id_field_name (str): Name of the new UUID column. table_name (str): Name of the table. From 2ee8da11f3c720fb6190db6af69debf994d91bd3 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 18 Aug 2024 23:55:38 +0200 Subject: [PATCH 056/113] Implement staging methods Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index ecdd22ca56..a0cb6bd5e4 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -1,3 +1,4 @@ +import contextlib from types import TracebackType from typing import ( List, @@ -12,6 +13,7 @@ Sequence, TYPE_CHECKING, Set, + Iterator, ) import lancedb # type: ignore @@ -41,6 +43,7 @@ FollowupJob, LoadJob, HasFollowupJobs, + WithStagingDataset, ) from dlt.common.exceptions import SystemConfigurationException from dlt.common.pendulum import timedelta @@ -210,10 +213,12 @@ def write_to_db( ) from e -class LanceDBClient(JobClientBase, WithStateSync): +class LanceDBClient(JobClientBase, WithStateSync, WithStagingDataset): """LanceDB destination handler.""" model_func: TextEmbeddingFunction + """The embedder callback used for each chunk.""" + dataset_name: str def __init__( self, @@ -231,6 +236,7 @@ def __init__( self.registry = EmbeddingFunctionRegistry.get_instance() self.type_mapper = LanceDBTypeMapper(self.capabilities) self.sentinel_table_name = config.sentinel_table_name + self.dataset_name = self.config.normalize_dataset_name(self.schema) embedding_model_provider = self.config.embedding_model_provider @@ -259,10 +265,6 @@ def __init__( self.vector_field_name = self.config.vector_field_name self.id_field_name = self.config.id_field_name - @property - def dataset_name(self) -> str: - return self.config.normalize_dataset_name(self.schema) - @property def sentinel_table(self) -> str: return 
self.make_qualified_table_name(self.sentinel_table_name) @@ -721,6 +723,20 @@ def create_table_chain_completed_followup_jobs( def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() + @contextlib.contextmanager + def with_staging_dataset(self) -> Iterator["LanceDBClient"]: + current_dataset_name = self.dataset_name + try: + self.dataset_name = self.schema.naming.normalize_table_identifier( + f"{current_dataset_name}_staging" + ) + yield self + finally: + self.dataset_name = current_dataset_name + + def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: + return table["write_disposition"] == "merge" + class LanceDBLoadJob(RunnableLoadJob, HasFollowupJobs): arrow_schema: TArrowSchema From 2947d55e94acd56769a3d58c4bd6db4e856d6569 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 19 Aug 2024 15:40:43 +0200 Subject: [PATCH 057/113] Override staging client methods Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index a0cb6bd5e4..53da145142 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -707,7 +707,6 @@ def create_table_chain_completed_followup_jobs( table_chain, completed_table_chain_jobs ) if table_chain[0].get("write_disposition") == "merge": - # TODO: Use staging to write deletion records. For now we use only one job. all_job_paths_ordered = [ job.file_path for table in table_chain @@ -735,7 +734,7 @@ def with_staging_dataset(self) -> Iterator["LanceDBClient"]: self.dataset_name = current_dataset_name def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: - return table["write_disposition"] == "merge" + return False class LanceDBLoadJob(RunnableLoadJob, HasFollowupJobs): From ea5914cf6a0a775a4d8be4ea387bc774819b6b9e Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 19 Aug 2024 16:46:03 +0200 Subject: [PATCH 058/113] Docs Signed-off-by: Marcel Coetzee --- .../dlt-ecosystem/destinations/lancedb.md | 19 +++++++++++++++++++ .../lancedb/test_remove_orphaned_records.py | 5 +---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md index dbf90da4b9..d8ec8e0490 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md +++ b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md @@ -179,6 +179,25 @@ pipeline.run( ) ``` +#### Orphan Removal + +To maintain referential integrity between parent document tables and chunk tables, you can automatically remove orphaned chunks when updating or deleting parent documents. +Specify the "x-lancedb-doc-id" hint as follows: + +```py +pipeline.run( + lancedb_adapter( + movies, + embed="title", + document_id="id" + ), + write_disposition="merge", +) +``` + +This sets `document_id` as the primary key and uses it to remove orphans in root tables and child tables recursively. +While it's technically possible to set both a primary key, and the `document_id` hint separately, doing so leads to confusing behavior and should be avoided. + ### Append This is the default disposition. It will append the data to the existing data in the destination. 
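A short end-to-end sketch to complement the new orphan removal docs above. It assumes the `document_id` adapter argument as documented, plus a configured LanceDB destination and embedding provider; `chunk_document`, the column names and the dataset name are illustrative only.

```py
import dlt
from dlt.destinations.impl.lancedb.lancedb_adapter import lancedb_adapter


def chunk_document(doc, size=64):
    # naive fixed-size chunker, used purely for illustration
    for i in range(0, len(doc["text"]), size):
        yield {"doc_id": doc["id"], "chunk": doc["text"][i : i + size]}


@dlt.resource(write_disposition="merge", table_name="document_chunks")
def chunks(docs):
    for doc in docs:
        yield from chunk_document(doc)


pipeline = dlt.pipeline(destination="lancedb", dataset_name="docs_dataset")

docs_v1 = [{"id": 1, "text": "first version of the document " * 10}]
docs_v2 = [{"id": 1, "text": "a much shorter update"}]  # produces fewer chunks

# With document_id set, re-loading an updated document should also delete the
# chunks that belonged to its previous version instead of leaving them behind.
pipeline.run(lancedb_adapter(chunks(docs_v1), embed="chunk", document_id="doc_id"))
pipeline.run(lancedb_adapter(chunks(docs_v2), embed="chunk", document_id="doc_id"))
```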
diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 3118b06cc7..56a304fd35 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -216,10 +216,7 @@ def documents_source( ) -> Any: return documents(docs) - lancedb_adapter( - documents, - embed=["chunk"], - ) + lancedb_adapter(documents, embed=["chunk"], document_id="doc_id") pipeline = dlt.pipeline( pipeline_name="test_lancedb_remove_orphaned_records_with_embeddings", From 1fcce519649b5b41d7a0f707020d99cdbbf4276a Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 20 Aug 2024 13:46:11 +0200 Subject: [PATCH 059/113] Override staging client methods Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 53da145142..1081aae576 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -780,7 +780,6 @@ def run(self) -> None: ) -# TODO: Implement staging for this step with insert deletes. class LanceDBRemoveOrphansJob(RunnableLoadJob): orphaned_ids: Set[str] From 8849f11cb56914ef89e5b0ba4ec59d346d1185a6 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 20 Aug 2024 21:23:22 +0200 Subject: [PATCH 060/113] Delete with inserts Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 155 +++++++----------- dlt/destinations/impl/lancedb/schema.py | 5 +- 2 files changed, 61 insertions(+), 99 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 1081aae576..82cbdcba26 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -19,13 +19,13 @@ import lancedb # type: ignore import lancedb.table # type: ignore import pyarrow as pa -import pyarrow.compute as pc import pyarrow.parquet as pq from lancedb import DBConnection +from lancedb.common import DATA # type: ignore from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore from numpy import ndarray -from pyarrow import Array, ChunkedArray, ArrowInvalid +from pyarrow import Array, ChunkedArray, ArrowInvalid, RecordBatchReader from dlt.common import json, pendulum, logger from dlt.common.destination import DestinationCapabilitiesContext @@ -45,7 +45,6 @@ HasFollowupJobs, WithStagingDataset, ) -from dlt.common.exceptions import SystemConfigurationException from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.schema.typing import ( @@ -57,7 +56,6 @@ ) from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName -from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import ( LanceDBClientConfiguration, ) @@ -66,7 +64,6 @@ ) from dlt.destinations.impl.lancedb.lancedb_adapter import ( VECTORIZE_HINT, - DOCUMENT_ID_HINT, ) from dlt.destinations.impl.lancedb.schema import ( make_arrow_field_schema, @@ -75,6 +72,7 @@ NULL_SCHEMA, TArrowField, arrow_datatype_to_fusion_datatype, + TTableLineage, ) from dlt.destinations.impl.lancedb.utils import ( get_unique_identifiers_from_table_schema, @@ -160,14 +158,15 
@@ def from_db_type( return super().from_db_type(cast(str, db_type), precision, scale) -def write_to_db( - records: Union[pa.Table, List[DictStrAny]], +def write_records( + records: DATA, /, *, db_client: DBConnection, table_name: str, write_disposition: Optional[TWriteDisposition] = "append", id_field_name: Optional[str] = None, + remove_orphans: Optional[bool] = False, ) -> None: """Inserts records into a LanceDB table with automatic embedding computation. @@ -177,6 +176,7 @@ def write_to_db( table_name: The name of the table to insert into. id_field_name: The name of the ID field for update/merge operations. write_disposition: The write disposition - one of 'skip', 'append', 'replace', 'merge'. + remove_orphans (bool): Whether to remove orphans after insertion or not (only merge disposition). Raises: ValueError: If the write disposition is unsupported, or `id_field_name` is not @@ -199,9 +199,12 @@ def write_to_db( elif write_disposition == "merge": if not id_field_name: raise ValueError("To perform a merge update, 'id_field_name' must be specified.") - tbl.merge_insert( - id_field_name - ).when_matched_update_all().when_not_matched_insert_all().execute(records) + if remove_orphans: + tbl.merge_insert(id_field_name).when_not_matched_by_source_delete().execute(records) + else: + tbl.merge_insert( + id_field_name + ).when_matched_update_all().when_not_matched_insert_all().execute(records) else: raise DestinationTerminalException( f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" @@ -534,7 +537,7 @@ def update_schema_in_storage(self) -> None: "write_disposition" ) - write_to_db( + write_records( records, db_client=self.db_client, table_name=fq_version_table_name, @@ -683,7 +686,7 @@ def complete_load(self, load_id: str) -> None: write_disposition = self.schema.get_table(self.schema.loads_table_name).get( "write_disposition" ) - write_to_db( + write_records( records, db_client=self.db_client, table_name=fq_loads_table_name, @@ -771,7 +774,7 @@ def run(self) -> None: id_field_name=id_field_name, ) - write_to_db( + write_records( arrow_table, db_client=db_client, table_name=fq_table_name, @@ -793,89 +796,47 @@ def __init__( def run(self) -> None: db_client: DBConnection = self._job_client.db_client - id_field_name: str = self._job_client.config.id_field_name - - # We don't all insert jobs for each table using this method. 
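The rename to `write_records` above centres on LanceDB's `merge_insert` builder. A toy, runnable sketch of the two code paths against a local LanceDB directory follows; the path, table and column values are invented for the example.

```py
import lancedb
import pyarrow as pa

db = lancedb.connect("/tmp/lancedb_merge_demo")

existing = pa.table({"_dlt_id": ["a", "b", "c"], "value": [1, 2, 3]})
tbl = db.create_table("demo", existing, mode="overwrite")

incoming = pa.table({"_dlt_id": ["b", "d"], "value": [20, 40]})

# Upsert path: update matching rows, insert new ones, keep untouched rows ("a", "c").
tbl.merge_insert(
    "_dlt_id"
).when_matched_update_all().when_not_matched_insert_all().execute(incoming)

# Orphan-removal path: target rows with no counterpart in the source are deleted,
# which is why the payload must carry every id that should survive the merge.
tbl.merge_insert("_dlt_id").when_not_matched_by_source_delete().execute(incoming)
```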
- table_lineage: List[TTableSchema] = [] - for file_path_ in self.references: - table = self._schema.get_table(ParsedLoadJobFileName.parse(file_path_).table_name) - if table["name"] not in [table_["name"] for table_ in table_lineage]: - table_lineage.append(table) - - for table in table_lineage: - fq_table_name: str = self._job_client.make_qualified_table_name(table["name"]) - try: - fq_parent_table_name: str = self._job_client.make_qualified_table_name( - table["parent"] + table_lineage: List[Tuple[TTableSchema, str, str]] = [ + ( + self._schema.get_table(ParsedLoadJobFileName.parse(file_path_).table_name), + ParsedLoadJobFileName.parse(file_path_).table_name, + file_path_, + ) + for file_path_ in self.references + ] + source_table_id_field_name = "_dlt_id" + + for table, table_name, table_path in table_lineage: + target_is_root_table = "parent" not in table + fq_table_name = self._job_client.make_qualified_table_name(table_name) + target_table_schema = pq.read_schema(table_path) + + if target_is_root_table: + target_table_id_field_name = "_dlt_id" + arrow_ds = pa.dataset.dataset(table_path) + else: + # TODO: change schema of source table id to math target table id. + target_table_id_field_name = "_dlt_parent_id" + parent_table_path = self.get_parent_path(table_lineage, table.get("parent")) + arrow_ds = pa.dataset.dataset(parent_table_path) + + arrow_rbr: RecordBatchReader + with arrow_ds.scanner( + columns=[source_table_id_field_name], + batch_size=BATCH_PROCESS_CHUNK_SIZE, + ).to_reader() as arrow_rbr: + records: pa.Table = arrow_rbr.read_all() + records_with_conforming_schema = records.join(target_table_schema.empty_table(), keys=source_table_id_field_name) + + write_records( + records_with_conforming_schema, + db_client=db_client, + id_field_name=target_table_id_field_name, + table_name=fq_table_name, + write_disposition="merge", + remove_orphans=True, ) - except KeyError: - fq_parent_table_name = None # The table is a root table. - - try: - child_table = db_client.open_table(fq_table_name) - child_table.checkout_latest() - if fq_parent_table_name: - parent_table = db_client.open_table(fq_parent_table_name) - parent_table.checkout_latest() - except FileNotFoundError as e: - raise DestinationTransientException( - "Couldn't open lancedb database. Orphan removal WILL BE RETRIED" - ) from e - - try: - if fq_parent_table_name: - # Chunks and embeddings in child table. - parent_ids = set( - pc.unique( - parent_table.to_lance().to_table(columns=["_dlt_id"])["_dlt_id"] - ).to_pylist() - ) - child_ids = set( - pc.unique( - child_table.to_lance().to_table(columns=["_dlt_parent_id"])[ - "_dlt_parent_id" - ] - ).to_pylist() - ) - - if orphaned_ids := child_ids - parent_ids: - if len(orphaned_ids) > 1: - child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids) == 1: - child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") - else: - # Chunks and embeddings in the root table. - document_id_field = get_columns_names_with_prop(table, DOCUMENT_ID_HINT) - if document_id_field and get_columns_names_with_prop(table, "primary_key"): - raise SystemConfigurationException( - "You CANNOT specify a primary key AND a document ID hint for the same" - " resource when using merge disposition." - ) - - # If document ID is defined, we use this as the sole grouping key to identify stale chunks, - # else fallback to the compound `id_field_name`. 
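For reference while reviewing this removal, the grouping-key logic being deleted boils down to the following standalone PyArrow sketch (toy data, invented ids): per grouping key only rows from the most recent `_dlt_load_id` survive, everything older is treated as an orphaned chunk.

```py
import pyarrow as pa
import pyarrow.compute as pc

chunks = pa.table(
    {
        "doc_id": ["d1", "d1", "d1", "d2"],
        "_dlt_load_id": ["100", "100", "200", "100"],
        "_dlt_id": ["c1", "c2", "c3", "c4"],
    }
)

grouping_key = ["doc_id"]
grouped = chunks.group_by(grouping_key).aggregate([("_dlt_load_id", "max")])
joined = chunks.join(grouped, keys=grouping_key)
orphaned_mask = pc.not_equal(joined["_dlt_load_id"], joined["_dlt_load_id_max"])
orphaned_ids = joined.filter(orphaned_mask).column("_dlt_id").to_pylist()

print(sorted(orphaned_ids))  # ['c1', 'c2'] - the "d1" chunks from the older load
```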
- grouping_key = document_id_field or id_field_name - grouping_key = ( - grouping_key if isinstance(grouping_key, list) else [grouping_key] - ) - child_table_arrow: pa.Table = child_table.to_lance().to_table( - columns=[*grouping_key, "_dlt_load_id", "_dlt_id"] - ) - - grouped = child_table_arrow.group_by(grouping_key).aggregate( - [("_dlt_load_id", "max")] - ) - joined = child_table_arrow.join(grouped, keys=grouping_key) - orphaned_mask = pc.not_equal(joined["_dlt_load_id"], joined["_dlt_load_id_max"]) - orphaned_ids = joined.filter(orphaned_mask).column("_dlt_id").to_pylist() - - if len(orphaned_ids) > 1: - child_table.delete(f"_dlt_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids) == 1: - child_table.delete(f"_dlt_id = '{orphaned_ids.pop()}'") - - except ArrowInvalid as e: - raise DestinationTerminalException( - "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." - ) from e + @staticmethod + def get_parent_path(table_lineage: TTableLineage, table: str) -> Optional[str]: + return next(entry[1] for entry in table_lineage if entry[1] == table) diff --git a/dlt/destinations/impl/lancedb/schema.py b/dlt/destinations/impl/lancedb/schema.py index db624aeb12..ff6f76c07a 100644 --- a/dlt/destinations/impl/lancedb/schema.py +++ b/dlt/destinations/impl/lancedb/schema.py @@ -3,7 +3,7 @@ from typing import ( List, cast, - Optional, + Optional, Tuple, ) import pyarrow as pa @@ -11,7 +11,7 @@ from typing_extensions import TypeAlias from dlt.common.json import json -from dlt.common.schema import Schema, TColumnSchema +from dlt.common.schema import Schema, TColumnSchema, TTableSchema from dlt.common.typing import DictStrAny from dlt.destinations.type_mapping import TypeMapper @@ -21,6 +21,7 @@ TArrowField: TypeAlias = pa.Field NULL_SCHEMA: TArrowSchema = pa.schema([]) """Empty pyarrow Schema with no fields.""" +TTableLineage: TypeAlias = List[Tuple[TTableSchema, str, str]] def arrow_schema_to_dict(schema: TArrowSchema) -> DictStrAny: From c7098fdf2c2f54f4c17f9841ae86859a1aaef15a Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 20 Aug 2024 21:23:48 +0200 Subject: [PATCH 061/113] Keep with batch reader Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 82cbdcba26..ea1a9796ac 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -825,11 +825,9 @@ def run(self) -> None: columns=[source_table_id_field_name], batch_size=BATCH_PROCESS_CHUNK_SIZE, ).to_reader() as arrow_rbr: - records: pa.Table = arrow_rbr.read_all() - records_with_conforming_schema = records.join(target_table_schema.empty_table(), keys=source_table_id_field_name) write_records( - records_with_conforming_schema, + arrow_rbr, db_client=db_client, id_field_name=target_table_id_field_name, table_name=fq_table_name, From d8ddcae9e20c875486a7ea69b1d1eef0335cd4c3 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 22 Aug 2024 13:53:43 +0200 Subject: [PATCH 062/113] Remove Lancedb client's staging implementation Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index ea1a9796ac..bb512bcf2e 100644 --- 
a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -216,7 +216,7 @@ def write_records( ) from e -class LanceDBClient(JobClientBase, WithStateSync, WithStagingDataset): +class LanceDBClient(JobClientBase, WithStateSync): """LanceDB destination handler.""" model_func: TextEmbeddingFunction @@ -725,19 +725,6 @@ def create_table_chain_completed_followup_jobs( def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() - @contextlib.contextmanager - def with_staging_dataset(self) -> Iterator["LanceDBClient"]: - current_dataset_name = self.dataset_name - try: - self.dataset_name = self.schema.naming.normalize_table_identifier( - f"{current_dataset_name}_staging" - ) - yield self - finally: - self.dataset_name = current_dataset_name - - def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: - return False class LanceDBLoadJob(RunnableLoadJob, HasFollowupJobs): From 17137a6033826c6189da41db0fc8588ab74da48c Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 22 Aug 2024 15:13:23 +0200 Subject: [PATCH 063/113] Insert in memory arrow table. This will be optimized Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 69 +++++++++++-------- 1 file changed, 41 insertions(+), 28 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index bb512bcf2e..3ee9e513b9 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -1,4 +1,3 @@ -import contextlib from types import TracebackType from typing import ( List, @@ -13,7 +12,6 @@ Sequence, TYPE_CHECKING, Set, - Iterator, ) import lancedb # type: ignore @@ -25,7 +23,7 @@ from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore from numpy import ndarray -from pyarrow import Array, ChunkedArray, ArrowInvalid, RecordBatchReader +from pyarrow import Array, ChunkedArray, ArrowInvalid from dlt.common import json, pendulum, logger from dlt.common.destination import DestinationCapabilitiesContext @@ -43,7 +41,6 @@ FollowupJob, LoadJob, HasFollowupJobs, - WithStagingDataset, ) from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables @@ -200,6 +197,7 @@ def write_records( if not id_field_name: raise ValueError("To perform a merge update, 'id_field_name' must be specified.") if remove_orphans: + # tbl.to_lance().merge_insert(id_field_name).when_not_matched_by_source_delete().execute(records) tbl.merge_insert(id_field_name).when_not_matched_by_source_delete().execute(records) else: tbl.merge_insert( @@ -726,7 +724,6 @@ def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() - class LanceDBLoadJob(RunnableLoadJob, HasFollowupJobs): arrow_schema: TArrowSchema @@ -791,36 +788,52 @@ def run(self) -> None: ) for file_path_ in self.references ] - source_table_id_field_name = "_dlt_id" - for table, table_name, table_path in table_lineage: - target_is_root_table = "parent" not in table - fq_table_name = self._job_client.make_qualified_table_name(table_name) - target_table_schema = pq.read_schema(table_path) + for target_table, target_table_name, target_table_path in table_lineage: + target_is_root_table = "parent" not in target_table + fq_table_name = self._job_client.make_qualified_table_name(target_table_name) if 
target_is_root_table: target_table_id_field_name = "_dlt_id" - arrow_ds = pa.dataset.dataset(table_path) + # arrow_ds = pa.dataset.dataset(table_path) + with FileStorage.open_zipsafe_ro(target_table_path, mode="rb") as f: + payload_arrow_table: pa.Table = pq.read_table(f) + + # Append ID Field, which is only defined in the LanceDB table. + payload_arrow_table = payload_arrow_table.append_column( + pa.field(self._job_client.id_field_name, pa.string()), + pa.array([""] * payload_arrow_table.num_rows, type=pa.string()), + ) else: # TODO: change schema of source table id to math target table id. target_table_id_field_name = "_dlt_parent_id" - parent_table_path = self.get_parent_path(table_lineage, table.get("parent")) - arrow_ds = pa.dataset.dataset(parent_table_path) - - arrow_rbr: RecordBatchReader - with arrow_ds.scanner( - columns=[source_table_id_field_name], - batch_size=BATCH_PROCESS_CHUNK_SIZE, - ).to_reader() as arrow_rbr: - - write_records( - arrow_rbr, - db_client=db_client, - id_field_name=target_table_id_field_name, - table_name=fq_table_name, - write_disposition="merge", - remove_orphans=True, - ) + parent_table_path = self.get_parent_path(table_lineage, target_table.get("parent")) + # arrow_ds = pa.dataset.dataset(parent_table_path) + with FileStorage.open_zipsafe_ro(parent_table_path, mode="rb") as f: + payload_arrow_table: pa.Table = pq.read_table(f) + + # arrow_rbr: RecordBatchReader + # with arrow_ds.scanner( + # columns=[source_table_id_field_name], + # batch_size=BATCH_PROCESS_CHUNK_SIZE, + # ).to_reader() as arrow_rbr: + # with FileStorage.open_zipsafe_ro(target_table_path, mode="rb") as f: + # target_table_arrow_schema: pa.Schema = pq.read_schema(f) + + # payload_arrow_table_with_conforming_schema = payload_arrow_table.join( + # target_table_arrow_schema.empty_table(), + # keys=target_table_id_field_name, + # ) + + write_records( + # payload_arrow_table_with_conforming_schema, + payload_arrow_table, + db_client=db_client, + id_field_name=target_table_id_field_name, + table_name=fq_table_name, + write_disposition="merge", + remove_orphans=True, + ) @staticmethod def get_parent_path(table_lineage: TTableLineage, table: str) -> Optional[str]: From 53d896a1ac08d031a71b96b99b981db822805b1b Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 26 Aug 2024 14:08:15 +0200 Subject: [PATCH 064/113] Rename classes to the new job implementation classes Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 131 +++++++----------- dlt/destinations/impl/lancedb/utils.py | 17 +++ 2 files changed, 67 insertions(+), 81 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 3ee9e513b9..d0ff3c29dd 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -38,9 +38,9 @@ RunnableLoadJob, StorageSchemaInfo, StateInfo, - FollowupJob, LoadJob, HasFollowupJobs, + FollowupJobRequest, ) from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables @@ -75,8 +75,9 @@ get_unique_identifiers_from_table_schema, set_non_standard_providers_environment_variables, generate_arrow_uuid_column, + get_default_arrow_value, ) -from dlt.destinations.job_impl import ReferenceFollowupJob +from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -197,7 +198,6 @@ def write_records( if not id_field_name: raise ValueError("To perform a merge 
update, 'id_field_name' must be specified.") if remove_orphans: - # tbl.to_lance().merge_insert(id_field_name).when_not_matched_by_source_delete().execute(records) tbl.merge_insert(id_field_name).when_not_matched_by_source_delete().execute(records) else: tbl.merge_insert( @@ -694,30 +694,17 @@ def complete_load(self, load_id: str) -> None: def create_load_job( self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: - if ReferenceFollowupJob.is_reference_job(file_path): + if ReferenceFollowupJobRequest.is_reference_job(file_path): return LanceDBRemoveOrphansJob(file_path) else: return LanceDBLoadJob(file_path, table) - def create_table_chain_completed_followup_jobs( - self, - table_chain: Sequence[TTableSchema], - completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[FollowupJob]: - jobs = super().create_table_chain_completed_followup_jobs( - table_chain, completed_table_chain_jobs - ) - if table_chain[0].get("write_disposition") == "merge": - all_job_paths_ordered = [ - job.file_path - for table in table_chain - for job in completed_table_chain_jobs - if job.job_file_info.table_name == table.get("name") - ] - root_table_file_name = FileStorage.get_file_name_from_file_path( - all_job_paths_ordered[0] - ) - jobs.append(ReferenceFollowupJob(root_table_file_name, all_job_paths_ordered)) + def create_table_chain_completed_followup_jobs(self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJobRequest]: + jobs = super().create_table_chain_completed_followup_jobs(table_chain, completed_table_chain_jobs) + if table_chain[0].get("write_disposition")=="merge": + all_job_paths_ordered = [job.file_path for table in table_chain for job in completed_table_chain_jobs if job.job_file_info.table_name==table.get("name")] + root_table_file_name = FileStorage.get_file_name_from_file_path(all_job_paths_ordered[0]) + jobs.append(ReferenceFollowupJobRequest(root_table_file_name, all_job_paths_ordered)) return jobs def table_exists(self, table_name: str) -> bool: @@ -727,11 +714,7 @@ def table_exists(self, table_name: str) -> bool: class LanceDBLoadJob(RunnableLoadJob, HasFollowupJobs): arrow_schema: TArrowSchema - def __init__( - self, - file_path: str, - table_schema: TTableSchema, - ) -> None: + def __init__(self, file_path: str, table_schema: TTableSchema, ) -> None: super().__init__(file_path) self._job_client: "LanceDBClient" = None self._table_schema: TTableSchema = table_schema @@ -740,43 +723,25 @@ def run(self) -> None: db_client: DBConnection = self._job_client.db_client fq_table_name: str = self._job_client.make_qualified_table_name(self._table_schema["name"]) id_field_name: str = self._job_client.config.id_field_name - unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema( - self._load_table - ) - write_disposition: TWriteDisposition = cast( - TWriteDisposition, self._load_table.get("write_disposition", "append") - ) + unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema(self._load_table) + write_disposition: TWriteDisposition = cast(TWriteDisposition, self._load_table.get("write_disposition", "append")) with FileStorage.open_zipsafe_ro(self._file_path, mode="rb") as f: arrow_table: pa.Table = pq.read_table(f) if self._load_table["name"] not in self._schema.dlt_table_names(): - arrow_table = generate_arrow_uuid_column( - arrow_table, - unique_identifiers=unique_identifiers, - table_name=fq_table_name, - 
id_field_name=id_field_name, - ) + arrow_table = generate_arrow_uuid_column(arrow_table, unique_identifiers=unique_identifiers, table_name=fq_table_name, id_field_name=id_field_name, ) - write_records( - arrow_table, - db_client=db_client, - table_name=fq_table_name, - write_disposition=write_disposition, - id_field_name=id_field_name, - ) + write_records(arrow_table, db_client=db_client, table_name=fq_table_name, write_disposition=write_disposition, id_field_name=id_field_name, ) class LanceDBRemoveOrphansJob(RunnableLoadJob): orphaned_ids: Set[str] - def __init__( - self, - file_path: str, - ) -> None: + def __init__(self, file_path: str, ) -> None: super().__init__(file_path) self._job_client: "LanceDBClient" = None - self.references = ReferenceFollowupJob.resolve_references(file_path) + self.references = ReferenceFollowupJobRequest.resolve_references(file_path) def run(self) -> None: db_client: DBConnection = self._job_client.db_client @@ -795,38 +760,42 @@ def run(self) -> None: if target_is_root_table: target_table_id_field_name = "_dlt_id" - # arrow_ds = pa.dataset.dataset(table_path) - with FileStorage.open_zipsafe_ro(target_table_path, mode="rb") as f: - payload_arrow_table: pa.Table = pq.read_table(f) - - # Append ID Field, which is only defined in the LanceDB table. - payload_arrow_table = payload_arrow_table.append_column( - pa.field(self._job_client.id_field_name, pa.string()), - pa.array([""] * payload_arrow_table.num_rows, type=pa.string()), - ) + file_path = target_table_path else: - # TODO: change schema of source table id to math target table id. target_table_id_field_name = "_dlt_parent_id" - parent_table_path = self.get_parent_path(table_lineage, target_table.get("parent")) - # arrow_ds = pa.dataset.dataset(parent_table_path) - with FileStorage.open_zipsafe_ro(parent_table_path, mode="rb") as f: - payload_arrow_table: pa.Table = pq.read_table(f) - - # arrow_rbr: RecordBatchReader - # with arrow_ds.scanner( - # columns=[source_table_id_field_name], - # batch_size=BATCH_PROCESS_CHUNK_SIZE, - # ).to_reader() as arrow_rbr: - # with FileStorage.open_zipsafe_ro(target_table_path, mode="rb") as f: - # target_table_arrow_schema: pa.Schema = pq.read_schema(f) - - # payload_arrow_table_with_conforming_schema = payload_arrow_table.join( - # target_table_arrow_schema.empty_table(), - # keys=target_table_id_field_name, - # ) + file_path = self.get_parent_path(table_lineage, target_table.get("parent")) + + with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: + payload_arrow_table: pa.Table = pq.read_table(f) + + # Get target table schema + with FileStorage.open_zipsafe_ro(target_table_path, mode="rb") as f: + target_table_schema: pa.Schema = pq.read_schema(f) + + # LanceDB requires the payload to have all fields populated, even if we don't intend to use them in our merge operation. + # Unfortunately, we can't just create NULL fields; else LanceDB always truncates the target using `when_not_matched_by_source_delete`. + schema_difference = pa.schema( + set(target_table_schema) - set(payload_arrow_table.schema) + ) + for field in schema_difference: + try: + default_value = get_default_arrow_value(field.type) + default_array = pa.array( + [default_value] * payload_arrow_table.num_rows, type=field.type + ) + payload_arrow_table = payload_arrow_table.append_column(field, default_array) + except ValueError as e: + logger.warn(f"{e}. 
Using null values for field '{field.name}'.") + null_array = pa.array([None] * payload_arrow_table.num_rows, type=field.type) + payload_arrow_table = payload_arrow_table.append_column(field, null_array) + + # TODO: Remove special field, we don't need it. + payload_arrow_table = payload_arrow_table.append_column( + pa.field(self._job_client.id_field_name, pa.string()), + pa.array([""] * payload_arrow_table.num_rows, type=pa.string()), + ) write_records( - # payload_arrow_table_with_conforming_schema, payload_arrow_table, db_client=db_client, id_field_name=target_table_id_field_name, @@ -837,4 +806,4 @@ def run(self) -> None: @staticmethod def get_parent_path(table_lineage: TTableLineage, table: str) -> Optional[str]: - return next(entry[1] for entry in table_lineage if entry[1] == table) + return next(entry[2] for entry in table_lineage if entry[1] == table) diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 37303686df..fe2bfa48bd 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -1,5 +1,6 @@ import os import uuid +from datetime import date, datetime from typing import Sequence, Union, Dict, List import pyarrow as pa @@ -74,3 +75,19 @@ def set_non_standard_providers_environment_variables( ) -> None: if embedding_model_provider in PROVIDER_ENVIRONMENT_VARIABLES_MAP: os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or "" + +def get_default_arrow_value(field_type): + if pa.types.is_integer(field_type): + return 0 + elif pa.types.is_floating(field_type): + return 0.0 + elif pa.types.is_string(field_type): + return "" + elif pa.types.is_boolean(field_type): + return False + elif pa.types.is_date(field_type): + return date.today() + elif pa.types.is_timestamp(field_type): + return datetime.now() + else: + raise ValueError(f"Unsupported data type: {field_type}") \ No newline at end of file From 26ba0f57f6db6d69864e0aec21103cac5d07bbb3 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 26 Aug 2024 15:03:49 +0200 Subject: [PATCH 065/113] Use namedtuple for table chain to improve readability Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 86 +++++++++++++------ dlt/destinations/impl/lancedb/schema.py | 6 +- dlt/destinations/impl/lancedb/utils.py | 8 +- 3 files changed, 71 insertions(+), 29 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index d0ff3c29dd..a099974d56 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -70,6 +70,7 @@ TArrowField, arrow_datatype_to_fusion_datatype, TTableLineage, + TableJob, ) from dlt.destinations.impl.lancedb.utils import ( get_unique_identifiers_from_table_schema, @@ -699,11 +700,24 @@ def create_load_job( else: return LanceDBLoadJob(file_path, table) - def create_table_chain_completed_followup_jobs(self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJobRequest]: - jobs = super().create_table_chain_completed_followup_jobs(table_chain, completed_table_chain_jobs) - if table_chain[0].get("write_disposition")=="merge": - all_job_paths_ordered = [job.file_path for table in table_chain for job in completed_table_chain_jobs if job.job_file_info.table_name==table.get("name")] - root_table_file_name = FileStorage.get_file_name_from_file_path(all_job_paths_ordered[0]) + def 
create_table_chain_completed_followup_jobs( + self, + table_chain: Sequence[TTableSchema], + completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, + ) -> List[FollowupJobRequest]: + jobs = super().create_table_chain_completed_followup_jobs( + table_chain, completed_table_chain_jobs + ) + if table_chain[0].get("write_disposition") == "merge": + all_job_paths_ordered = [ + job.file_path + for table in table_chain + for job in completed_table_chain_jobs + if job.job_file_info.table_name == table.get("name") + ] + root_table_file_name = FileStorage.get_file_name_from_file_path( + all_job_paths_ordered[0] + ) jobs.append(ReferenceFollowupJobRequest(root_table_file_name, all_job_paths_ordered)) return jobs @@ -714,7 +728,11 @@ def table_exists(self, table_name: str) -> bool: class LanceDBLoadJob(RunnableLoadJob, HasFollowupJobs): arrow_schema: TArrowSchema - def __init__(self, file_path: str, table_schema: TTableSchema, ) -> None: + def __init__( + self, + file_path: str, + table_schema: TTableSchema, + ) -> None: super().__init__(file_path) self._job_client: "LanceDBClient" = None self._table_schema: TTableSchema = table_schema @@ -723,53 +741,73 @@ def run(self) -> None: db_client: DBConnection = self._job_client.db_client fq_table_name: str = self._job_client.make_qualified_table_name(self._table_schema["name"]) id_field_name: str = self._job_client.config.id_field_name - unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema(self._load_table) - write_disposition: TWriteDisposition = cast(TWriteDisposition, self._load_table.get("write_disposition", "append")) + unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema( + self._load_table + ) + write_disposition: TWriteDisposition = cast( + TWriteDisposition, self._load_table.get("write_disposition", "append") + ) with FileStorage.open_zipsafe_ro(self._file_path, mode="rb") as f: arrow_table: pa.Table = pq.read_table(f) if self._load_table["name"] not in self._schema.dlt_table_names(): - arrow_table = generate_arrow_uuid_column(arrow_table, unique_identifiers=unique_identifiers, table_name=fq_table_name, id_field_name=id_field_name, ) + arrow_table = generate_arrow_uuid_column( + arrow_table, + unique_identifiers=unique_identifiers, + table_name=fq_table_name, + id_field_name=id_field_name, + ) - write_records(arrow_table, db_client=db_client, table_name=fq_table_name, write_disposition=write_disposition, id_field_name=id_field_name, ) + write_records( + arrow_table, + db_client=db_client, + table_name=fq_table_name, + write_disposition=write_disposition, + id_field_name=id_field_name, + ) class LanceDBRemoveOrphansJob(RunnableLoadJob): orphaned_ids: Set[str] - def __init__(self, file_path: str, ) -> None: + def __init__( + self, + file_path: str, + ) -> None: super().__init__(file_path) self._job_client: "LanceDBClient" = None self.references = ReferenceFollowupJobRequest.resolve_references(file_path) def run(self) -> None: db_client: DBConnection = self._job_client.db_client - table_lineage: List[Tuple[TTableSchema, str, str]] = [ - ( - self._schema.get_table(ParsedLoadJobFileName.parse(file_path_).table_name), - ParsedLoadJobFileName.parse(file_path_).table_name, - file_path_, + table_lineage: TTableLineage = [ + TableJob( + table_schema=self._schema.get_table( + ParsedLoadJobFileName.parse(file_path_).table_name + ), + table_name=ParsedLoadJobFileName.parse(file_path_).table_name, + file_path=file_path_, ) for file_path_ in self.references ] - for target_table, target_table_name, 
target_table_path in table_lineage: - target_is_root_table = "parent" not in target_table - fq_table_name = self._job_client.make_qualified_table_name(target_table_name) + for job in table_lineage: + target_is_root_table = "parent" not in job.table_schema + fq_table_name = self._job_client.make_qualified_table_name(job.table_name) if target_is_root_table: target_table_id_field_name = "_dlt_id" - file_path = target_table_path + file_path = job.file_path else: target_table_id_field_name = "_dlt_parent_id" - file_path = self.get_parent_path(table_lineage, target_table.get("parent")) + file_path = self.get_parent_path(table_lineage, job.table_schema.get("parent")) with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: payload_arrow_table: pa.Table = pq.read_table(f) # Get target table schema - with FileStorage.open_zipsafe_ro(target_table_path, mode="rb") as f: + with FileStorage.open_zipsafe_ro(job.file_path, mode="rb") as f: target_table_schema: pa.Schema = pq.read_schema(f) # LanceDB requires the payload to have all fields populated, even if we don't intend to use them in our merge operation. @@ -805,5 +843,5 @@ def run(self) -> None: ) @staticmethod - def get_parent_path(table_lineage: TTableLineage, table: str) -> Optional[str]: - return next(entry[2] for entry in table_lineage if entry[1] == table) + def get_parent_path(table_lineage: TTableLineage, table: str) -> Any: + return next(entry.file_path for entry in table_lineage if entry.table_name == table) diff --git a/dlt/destinations/impl/lancedb/schema.py b/dlt/destinations/impl/lancedb/schema.py index ff6f76c07a..6cc562038e 100644 --- a/dlt/destinations/impl/lancedb/schema.py +++ b/dlt/destinations/impl/lancedb/schema.py @@ -1,5 +1,5 @@ """Utilities for creating arrow schemas from table schemas.""" - +from collections import namedtuple from typing import ( List, cast, @@ -21,7 +21,8 @@ TArrowField: TypeAlias = pa.Field NULL_SCHEMA: TArrowSchema = pa.schema([]) """Empty pyarrow Schema with no fields.""" -TTableLineage: TypeAlias = List[Tuple[TTableSchema, str, str]] +TableJob = namedtuple('TableJob', ['table_schema', 'table_name', 'file_path']) +TTableLineage: TypeAlias = List[TableJob] def arrow_schema_to_dict(schema: TArrowSchema) -> DictStrAny: @@ -102,3 +103,4 @@ def arrow_datatype_to_fusion_datatype(arrow_type: TArrowSchema) -> str: return "TIMESTAMP" return type_map.get(arrow_type, "UNKNOWN") + diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index fe2bfa48bd..e1847120bd 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -9,6 +9,7 @@ from dlt.common.schema import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider +from dlt.destinations.impl.lancedb.schema import TArrowDataType PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = { @@ -62,7 +63,7 @@ def get_unique_identifiers_from_table_schema(table_schema: TTableSchema) -> List """ primary_keys = get_columns_names_with_prop(table_schema, "primary_key") merge_keys = [] - if table_schema.get("write_disposition") == "merge": + if table_schema.get("write_disposition")=="merge": merge_keys = get_columns_names_with_prop(table_schema, "merge_key") if join_keys := list(set(primary_keys + merge_keys)): return join_keys @@ -76,7 +77,8 @@ def set_non_standard_providers_environment_variables( if embedding_model_provider in PROVIDER_ENVIRONMENT_VARIABLES_MAP: 
os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or "" -def get_default_arrow_value(field_type): + +def get_default_arrow_value(field_type: TArrowDataType) -> object: if pa.types.is_integer(field_type): return 0 elif pa.types.is_floating(field_type): @@ -90,4 +92,4 @@ def get_default_arrow_value(field_type): elif pa.types.is_timestamp(field_type): return datetime.now() else: - raise ValueError(f"Unsupported data type: {field_type}") \ No newline at end of file + raise ValueError(f"Unsupported data type: {field_type}") From 06e04d9fe6fbb76095129d93f5872c3c42566ca8 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 26 Aug 2024 17:43:03 +0200 Subject: [PATCH 066/113] Remove orphans by loading all ancestor IDs simultaneously Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 35 +++++++++++++------ dlt/destinations/impl/lancedb/utils.py | 26 +++++++++++++- 2 files changed, 50 insertions(+), 11 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index a099974d56..4747856c73 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -77,6 +77,7 @@ set_non_standard_providers_environment_variables, generate_arrow_uuid_column, get_default_arrow_value, + create_unique_table_lineage, ) from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.type_mapping import TypeMapper @@ -711,9 +712,8 @@ def create_table_chain_completed_followup_jobs( if table_chain[0].get("write_disposition") == "merge": all_job_paths_ordered = [ job.file_path - for table in table_chain + for _ in table_chain for job in completed_table_chain_jobs - if job.job_file_info.table_name == table.get("name") ] root_table_file_name = FileStorage.get_file_name_from_file_path( all_job_paths_ordered[0] @@ -791,22 +791,35 @@ def run(self) -> None: ) for file_path_ in self.references ] + table_lineage_unique: TTableLineage = create_unique_table_lineage(table_lineage) - for job in table_lineage: + for job in table_lineage_unique: target_is_root_table = "parent" not in job.table_schema fq_table_name = self._job_client.make_qualified_table_name(job.table_name) if target_is_root_table: target_table_id_field_name = "_dlt_id" - file_path = job.file_path + ancestors_file_paths = self.get_parent_paths(table_lineage, job.table_name) else: target_table_id_field_name = "_dlt_parent_id" - file_path = self.get_parent_path(table_lineage, job.table_schema.get("parent")) + ancestors_file_paths = self.get_parent_paths( + table_lineage, job.table_schema.get("parent") + ) - with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: - payload_arrow_table: pa.Table = pq.read_table(f) + # `when_not_matched_by_source_delete` removes absent source IDs. + # Loading ancestors individually risks unintended ID deletion, necessitating simultaneous loading of all ancestor IDs. + payload_arrow_table = None + for file_path_ in ancestors_file_paths: + with FileStorage.open_zipsafe_ro(file_path_, mode="rb") as f: + ancestor_arrow_table: pa.Table = pq.read_table(f) + if payload_arrow_table is None: + payload_arrow_table = ancestor_arrow_table + else: + payload_arrow_table = pa.concat_tables( + [payload_arrow_table, ancestor_arrow_table] + ) - # Get target table schema + # Get target table schema. 
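The accumulation loop above can be collapsed into a single `pa.concat_tables` call; a tiny runnable equivalent follows. The parquet files are written inline to stand in for dlt load files, which the real job opens through `FileStorage.open_zipsafe_ro`.

```py
import pyarrow as pa
import pyarrow.parquet as pq

# stand-in ancestor load files
pq.write_table(pa.table({"_dlt_id": ["a", "b"]}), "load_1.parquet")
pq.write_table(pa.table({"_dlt_id": ["c"]}), "load_2.parquet")

ancestors_file_paths = ["load_1.parquet", "load_2.parquet"]
ancestor_tables = [pq.read_table(path) for path in ancestors_file_paths]
# one concat instead of growing the payload pairwise inside the loop
payload_arrow_table = pa.concat_tables(ancestor_tables)
print(payload_arrow_table.num_rows)  # 3
```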
with FileStorage.open_zipsafe_ro(job.file_path, mode="rb") as f: target_table_schema: pa.Schema = pq.read_schema(f) @@ -843,5 +856,7 @@ def run(self) -> None: ) @staticmethod - def get_parent_path(table_lineage: TTableLineage, table: str) -> Any: - return next(entry.file_path for entry in table_lineage if entry.table_name == table) + def get_parent_paths(table_lineage: TTableLineage, table: str) -> List[str]: + """Return all load files for a given table in the same order in which they were + loaded, thereby maintaining the load history of the table.""" + return [entry.file_path for entry in table_lineage if entry.table_name == table] diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index e1847120bd..624f3f4a4f 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -9,7 +9,7 @@ from dlt.common.schema import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider -from dlt.destinations.impl.lancedb.schema import TArrowDataType +from dlt.destinations.impl.lancedb.schema import TArrowDataType, TTableLineage PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = { @@ -93,3 +93,27 @@ def get_default_arrow_value(field_type: TArrowDataType) -> object: return datetime.now() else: raise ValueError(f"Unsupported data type: {field_type}") + + +def create_unique_table_lineage(table_lineage: TTableLineage) -> TTableLineage: + """Create a unique table lineage, keeping the last job for each table. + + Args: + table_lineage: The full table lineage. + + Returns: + A new list of TableJob objects with the duplicates removed, keeping the + last occurrence of each unique table name while maintaining the + original order of appearance. + """ + seen_table_names = set() + unique_lineage = [] + + for job in reversed(table_lineage): + if job.table_name not in seen_table_names: + seen_table_names.add(job.table_name) + unique_lineage.append(job) + + return list(reversed(unique_lineage)) + + From 470315ed1d124c0d5f1f51caff52b2aa53213034 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 26 Aug 2024 19:01:31 +0200 Subject: [PATCH 067/113] Fix doc_id adapter Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_adapter.py | 2 +- tests/load/lancedb/test_remove_orphaned_records.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_adapter.py b/dlt/destinations/impl/lancedb/lancedb_adapter.py index 2f6b44d131..0daba7a651 100644 --- a/dlt/destinations/impl/lancedb/lancedb_adapter.py +++ b/dlt/destinations/impl/lancedb/lancedb_adapter.py @@ -56,7 +56,7 @@ def lancedb_adapter( if document_id: if isinstance(document_id, str): - embed = [document_id] + document_id = [document_id] if not isinstance(document_id, list): raise ValueError( "'document_id' must be a list of column names or a single column name as a string." 
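Back to the `create_unique_table_lineage` helper added in this patch: a tiny example makes its contract clearer. Duplicate table names are dropped, the last job per table wins, and the relative order of the survivors is kept. `TableJob` mirrors the namedtuple from `schema.py`; the file names are invented.

```py
from collections import namedtuple

TableJob = namedtuple("TableJob", ["table_schema", "table_name", "file_path"])

lineage = [
    TableJob({}, "document", "document.1.parquet"),
    TableJob({}, "document__chunks", "document__chunks.1.parquet"),
    TableJob({}, "document", "document.2.parquet"),
]

seen_table_names = set()
unique_lineage = []
for job in reversed(lineage):
    if job.table_name not in seen_table_names:
        seen_table_names.add(job.table_name)
        unique_lineage.append(job)
unique_lineage = list(reversed(unique_lineage))

print([job.file_path for job in unique_lineage])
# ['document__chunks.1.parquet', 'document.2.parquet']
```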
diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 56a304fd35..37bca8ffb0 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -198,11 +198,10 @@ def identity_resource( def test_lancedb_root_table_remove_orphaned_records_with_real_embeddings() -> None: - @dlt.resource( # type: ignore + @dlt.resource( write_disposition="merge", table_name="document", merge_key=["chunk"], - columns={"doc_id": {DOCUMENT_ID_HINT: True}}, ) def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: for doc in docs: From 43eb5b4aa87dfe05d5c839b894f297cab638af14 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 26 Aug 2024 19:01:31 +0200 Subject: [PATCH 068/113] Fix doc_id adapter Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_adapter.py | 2 +- tests/load/lancedb/test_remove_orphaned_records.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_adapter.py b/dlt/destinations/impl/lancedb/lancedb_adapter.py index 2f6b44d131..0daba7a651 100644 --- a/dlt/destinations/impl/lancedb/lancedb_adapter.py +++ b/dlt/destinations/impl/lancedb/lancedb_adapter.py @@ -56,7 +56,7 @@ def lancedb_adapter( if document_id: if isinstance(document_id, str): - embed = [document_id] + document_id = [document_id] if not isinstance(document_id, list): raise ValueError( "'document_id' must be a list of column names or a single column name as a string." diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 3118b06cc7..6346171574 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -198,11 +198,10 @@ def identity_resource( def test_lancedb_root_table_remove_orphaned_records_with_real_embeddings() -> None: - @dlt.resource( # type: ignore + @dlt.resource( write_disposition="merge", table_name="document", merge_key=["chunk"], - columns={"doc_id": {DOCUMENT_ID_HINT: True}}, ) def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: for doc in docs: From 40a5e7372abe24f496cffa33bfc7a10863cd0faf Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 26 Aug 2024 19:13:44 +0200 Subject: [PATCH 069/113] Revert to previous Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 4747856c73..c2513c022b 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -712,8 +712,9 @@ def create_table_chain_completed_followup_jobs( if table_chain[0].get("write_disposition") == "merge": all_job_paths_ordered = [ job.file_path - for _ in table_chain + for table in table_chain for job in completed_table_chain_jobs + if job.job_file_info.table_name == table.get("name") ] root_table_file_name = FileStorage.get_file_name_from_file_path( all_job_paths_ordered[0] From 8cd6003f4e0bc8b821b83ed7515bf762f8280c14 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 27 Aug 2024 13:39:57 +0200 Subject: [PATCH 070/113] Revert "Remove orphans by loading all ancestor IDs simultaneously" This reverts commit 06e04d9fe6fbb76095129d93f5872c3c42566ca8. 
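The reverts above put back the per-table filter used when collecting file paths for the orphan-removal reference job. A simplified, standalone rendering of that path-collection logic, with `SimpleNamespace` objects standing in for dlt's `LoadJobInfo`; table and file names are invented.

```py
from types import SimpleNamespace

def collect_merge_followup_paths(table_chain, completed_jobs):
    # only merge chains get a reference follow-up job; paths keep the
    # table-chain order, so the root table's files come first
    if table_chain[0].get("write_disposition") != "merge":
        return []
    return [
        job.file_path
        for table in table_chain
        for job in completed_jobs
        if job.job_file_info.table_name == table.get("name")
    ]

table_chain = [
    {"name": "document", "write_disposition": "merge"},
    {"name": "document__chunks", "write_disposition": "merge"},
]
completed_jobs = [
    SimpleNamespace(
        file_path="document__chunks.0.parquet",
        job_file_info=SimpleNamespace(table_name="document__chunks"),
    ),
    SimpleNamespace(
        file_path="document.0.parquet",
        job_file_info=SimpleNamespace(table_name="document"),
    ),
]

print(collect_merge_followup_paths(table_chain, completed_jobs))
# ['document.0.parquet', 'document__chunks.0.parquet']
```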
--- .../impl/lancedb/lancedb_client.py | 32 +++++-------------- dlt/destinations/impl/lancedb/utils.py | 26 +-------------- 2 files changed, 9 insertions(+), 49 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index c2513c022b..a099974d56 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -77,7 +77,6 @@ set_non_standard_providers_environment_variables, generate_arrow_uuid_column, get_default_arrow_value, - create_unique_table_lineage, ) from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.type_mapping import TypeMapper @@ -792,35 +791,22 @@ def run(self) -> None: ) for file_path_ in self.references ] - table_lineage_unique: TTableLineage = create_unique_table_lineage(table_lineage) - for job in table_lineage_unique: + for job in table_lineage: target_is_root_table = "parent" not in job.table_schema fq_table_name = self._job_client.make_qualified_table_name(job.table_name) if target_is_root_table: target_table_id_field_name = "_dlt_id" - ancestors_file_paths = self.get_parent_paths(table_lineage, job.table_name) + file_path = job.file_path else: target_table_id_field_name = "_dlt_parent_id" - ancestors_file_paths = self.get_parent_paths( - table_lineage, job.table_schema.get("parent") - ) + file_path = self.get_parent_path(table_lineage, job.table_schema.get("parent")) - # `when_not_matched_by_source_delete` removes absent source IDs. - # Loading ancestors individually risks unintended ID deletion, necessitating simultaneous loading of all ancestor IDs. - payload_arrow_table = None - for file_path_ in ancestors_file_paths: - with FileStorage.open_zipsafe_ro(file_path_, mode="rb") as f: - ancestor_arrow_table: pa.Table = pq.read_table(f) - if payload_arrow_table is None: - payload_arrow_table = ancestor_arrow_table - else: - payload_arrow_table = pa.concat_tables( - [payload_arrow_table, ancestor_arrow_table] - ) + with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: + payload_arrow_table: pa.Table = pq.read_table(f) - # Get target table schema. 
+ # Get target table schema with FileStorage.open_zipsafe_ro(job.file_path, mode="rb") as f: target_table_schema: pa.Schema = pq.read_schema(f) @@ -857,7 +843,5 @@ def run(self) -> None: ) @staticmethod - def get_parent_paths(table_lineage: TTableLineage, table: str) -> List[str]: - """Return all load files for a given table in the same order in which they were - loaded, thereby maintaining the load history of the table.""" - return [entry.file_path for entry in table_lineage if entry.table_name == table] + def get_parent_path(table_lineage: TTableLineage, table: str) -> Any: + return next(entry.file_path for entry in table_lineage if entry.table_name == table) diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 624f3f4a4f..e1847120bd 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -9,7 +9,7 @@ from dlt.common.schema import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider -from dlt.destinations.impl.lancedb.schema import TArrowDataType, TTableLineage +from dlt.destinations.impl.lancedb.schema import TArrowDataType PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = { @@ -93,27 +93,3 @@ def get_default_arrow_value(field_type: TArrowDataType) -> object: return datetime.now() else: raise ValueError(f"Unsupported data type: {field_type}") - - -def create_unique_table_lineage(table_lineage: TTableLineage) -> TTableLineage: - """Create a unique table lineage, keeping the last job for each table. - - Args: - table_lineage: The full table lineage. - - Returns: - A new list of TableJob objects with the duplicates removed, keeping the - last occurrence of each unique table name while maintaining the - original order of appearance. - """ - seen_table_names = set() - unique_lineage = [] - - for job in reversed(table_lineage): - if job.table_name not in seen_table_names: - seen_table_names.add(job.table_name) - unique_lineage.append(job) - - return list(reversed(unique_lineage)) - - From dad103e9e12494d4548b7428de4f64374efa8c54 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 27 Aug 2024 19:51:13 +0200 Subject: [PATCH 071/113] Remove doc_id hint Signed-off-by: Marcel Coetzee --- .../impl/lancedb/configuration.py | 2 - .../impl/lancedb/lancedb_adapter.py | 20 +- .../impl/lancedb/lancedb_client.py | 44 ++--- dlt/destinations/impl/lancedb/schema.py | 4 - tests/load/lancedb/test_pipeline.py | 175 +----------------- .../lancedb/test_remove_orphaned_records.py | 17 +- 6 files changed, 27 insertions(+), 235 deletions(-) diff --git a/dlt/destinations/impl/lancedb/configuration.py b/dlt/destinations/impl/lancedb/configuration.py index 5aa4ba714f..92f88d562b 100644 --- a/dlt/destinations/impl/lancedb/configuration.py +++ b/dlt/destinations/impl/lancedb/configuration.py @@ -93,8 +93,6 @@ class LanceDBClientConfiguration(DestinationClientDwhConfiguration): Make sure it corresponds with the associated embedding model's dimensionality.""" vector_field_name: str = "vector__" """Name of the special field to store the vector embeddings.""" - id_field_name: str = "id__" - """Name of the special field to manage deduplication.""" sentinel_table_name: str = "dltSentinelTable" """Name of the sentinel table that encapsulates datasets. 
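A rough illustration of the sentinel-table idea referenced in the docstring above: since LanceDB has no schema or namespace concept, dlt prefixes table names with the dataset name and keeps one marker table per dataset. The separator, table names and local path below are assumptions made for the sketch; the real client derives them from its configuration and `make_qualified_table_name`.

```py
import lancedb
import pyarrow as pa

db = lancedb.connect("/tmp/lancedb_sentinel_demo")
dataset_name = "my_dataset"
sentinel_table_name = f"{dataset_name}___dltSentinelTable"  # assumed separator

if sentinel_table_name not in db.table_names():
    # an empty, schema-only table is enough to mark that the dataset exists
    db.create_table(
        sentinel_table_name, schema=pa.schema([pa.field("id", pa.string())])
    )

dataset_tables = [name for name in db.table_names() if name.startswith(dataset_name)]
print(dataset_tables)
```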
Since LanceDB has no concept of schemas, this table serves as a proxy to group related dlt tables together.""" diff --git a/dlt/destinations/impl/lancedb/lancedb_adapter.py b/dlt/destinations/impl/lancedb/lancedb_adapter.py index 0daba7a651..df356430e4 100644 --- a/dlt/destinations/impl/lancedb/lancedb_adapter.py +++ b/dlt/destinations/impl/lancedb/lancedb_adapter.py @@ -6,13 +6,11 @@ VECTORIZE_HINT = "x-lancedb-embed" -DOCUMENT_ID_HINT = "x-lancedb-doc-id" def lancedb_adapter( data: Any, embed: TColumnNames = None, - document_id: TColumnNames = None, ) -> DltResource: """Prepares data for the LanceDB destination by specifying which columns should be embedded. @@ -22,8 +20,6 @@ def lancedb_adapter( object. embed (TColumnNames, optional): Specify columns to generate embeddings for. It can be a single column name as a string, or a list of column names. - document_id (TColumnNames, optional): Specify columns which represenet the document - and which will be appended to primary/merge keys. Returns: DltResource: A resource with applied LanceDB-specific hints. @@ -54,22 +50,8 @@ def lancedb_adapter( VECTORIZE_HINT: True, # type: ignore[misc] } - if document_id: - if isinstance(document_id, str): - document_id = [document_id] - if not isinstance(document_id, list): - raise ValueError( - "'document_id' must be a list of column names or a single column name as a string." - ) - - for column_name in document_id: - column_hints[column_name] = { - "name": column_name, - DOCUMENT_ID_HINT: True, # type: ignore[misc] - } - if not column_hints: - raise ValueError("At least one of 'embed' or 'document_id' must be specified.") + raise ValueError("You must must provide the 'embed' argument if using the adapter.") else: resource.apply_hints(columns=column_hints) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index a099974d56..57639a139a 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -73,9 +73,7 @@ TableJob, ) from dlt.destinations.impl.lancedb.utils import ( - get_unique_identifiers_from_table_schema, set_non_standard_providers_environment_variables, - generate_arrow_uuid_column, get_default_arrow_value, ) from dlt.destinations.job_impl import ReferenceFollowupJobRequest @@ -164,7 +162,7 @@ def write_records( db_client: DBConnection, table_name: str, write_disposition: Optional[TWriteDisposition] = "append", - id_field_name: Optional[str] = None, + merge_key: Optional[Union[str, List[str]]] = None, remove_orphans: Optional[bool] = False, ) -> None: """Inserts records into a LanceDB table with automatic embedding computation. @@ -173,7 +171,7 @@ def write_records( records: The data to be inserted as payload. db_client: The LanceDB client connection. table_name: The name of the table to insert into. - id_field_name: The name of the ID field for update/merge operations. + merge_key: Keys for update/merge operations. write_disposition: The write disposition - one of 'skip', 'append', 'replace', 'merge'. remove_orphans (bool): Whether to remove orphans after insertion or not (only merge disposition). 
@@ -196,13 +194,11 @@ def write_records( elif write_disposition == "replace": tbl.add(records, mode="overwrite") elif write_disposition == "merge": - if not id_field_name: - raise ValueError("To perform a merge update, 'id_field_name' must be specified.") if remove_orphans: - tbl.merge_insert(id_field_name).when_not_matched_by_source_delete().execute(records) + tbl.merge_insert(merge_key).when_not_matched_by_source_delete().execute(records) else: tbl.merge_insert( - id_field_name + merge_key ).when_matched_update_all().when_not_matched_insert_all().execute(records) else: raise DestinationTerminalException( @@ -265,7 +261,6 @@ def __init__( ) self.vector_field_name = self.config.vector_field_name - self.id_field_name = self.config.id_field_name @property def sentinel_table(self) -> str: @@ -488,7 +483,6 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: self.schema.get_table(table_name=table_name), VECTORIZE_HINT ) vector_field_name = self.vector_field_name - id_field_name = self.id_field_name embedding_model_func = self.model_func embedding_model_dimensions = self.config.embedding_model_dimensions else: @@ -506,7 +500,6 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: embedding_model_func=embedding_model_func, embedding_model_dimensions=embedding_model_dimensions, vector_field_name=vector_field_name, - id_field_name=id_field_name, ) fq_table_name = self.make_qualified_table_name(table_name) self.create_table(fq_table_name, table_schema) @@ -740,10 +733,6 @@ def __init__( def run(self) -> None: db_client: DBConnection = self._job_client.db_client fq_table_name: str = self._job_client.make_qualified_table_name(self._table_schema["name"]) - id_field_name: str = self._job_client.config.id_field_name - unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema( - self._load_table - ) write_disposition: TWriteDisposition = cast( TWriteDisposition, self._load_table.get("write_disposition", "append") ) @@ -751,20 +740,21 @@ def run(self) -> None: with FileStorage.open_zipsafe_ro(self._file_path, mode="rb") as f: arrow_table: pa.Table = pq.read_table(f) - if self._load_table["name"] not in self._schema.dlt_table_names(): - arrow_table = generate_arrow_uuid_column( - arrow_table, - unique_identifiers=unique_identifiers, - table_name=fq_table_name, - id_field_name=id_field_name, - ) + # TODO: To function + merge_key = ( + get_columns_names_with_prop(self._load_table, "merge_key") + or get_columns_names_with_prop(self._load_table, "primary_key") + or get_columns_names_with_prop(self._load_table, "unique") + ) + if isinstance(merge_key, list) and len(merge_key) >= 1: + merge_key = merge_key[0] write_records( arrow_table, db_client=db_client, table_name=fq_table_name, write_disposition=write_disposition, - id_field_name=id_field_name, + merge_key=merge_key, ) @@ -827,16 +817,10 @@ def run(self) -> None: null_array = pa.array([None] * payload_arrow_table.num_rows, type=field.type) payload_arrow_table = payload_arrow_table.append_column(field, null_array) - # TODO: Remove special field, we don't need it. 
- payload_arrow_table = payload_arrow_table.append_column( - pa.field(self._job_client.id_field_name, pa.string()), - pa.array([""] * payload_arrow_table.num_rows, type=pa.string()), - ) - write_records( payload_arrow_table, db_client=db_client, - id_field_name=target_table_id_field_name, + merge_key=target_table_id_field_name, # type: ignore table_name=fq_table_name, write_disposition="merge", remove_orphans=True, diff --git a/dlt/destinations/impl/lancedb/schema.py b/dlt/destinations/impl/lancedb/schema.py index 6cc562038e..66542fe2b1 100644 --- a/dlt/destinations/impl/lancedb/schema.py +++ b/dlt/destinations/impl/lancedb/schema.py @@ -43,7 +43,6 @@ def make_arrow_table_schema( table_name: str, schema: Schema, type_mapper: TypeMapper, - id_field_name: Optional[str] = None, vector_field_name: Optional[str] = None, embedding_fields: Optional[List[str]] = None, embedding_model_func: Optional[TextEmbeddingFunction] = None, @@ -52,9 +51,6 @@ def make_arrow_table_schema( """Creates a PyArrow schema from a dlt schema.""" arrow_schema: List[TArrowField] = [] - if id_field_name: - arrow_schema.append(pa.field(id_field_name, pa.string())) - if embedding_fields: # User's provided dimension config, if provided, takes precedence. vec_size = embedding_model_dimensions or embedding_model_func.ndims() diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 4b964604e6..7405b8426e 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -12,7 +12,6 @@ from dlt.destinations.impl.lancedb.lancedb_adapter import ( lancedb_adapter, VECTORIZE_HINT, - DOCUMENT_ID_HINT, ) from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient from dlt.extract import DltResource @@ -52,18 +51,6 @@ def some_data() -> Generator[DictStrStr, Any, None]: "x-lancedb-embed": True, } - lancedb_adapter( - some_data, - document_id=["content"], - ) - - assert some_data.columns["content"] == { # type: ignore - "name": "content", - "data_type": "text", - "x-lancedb-embed": True, - "x-lancedb-doc-id": True, - } - def test_basic_state_and_schema() -> None: generator_instance1 = sequence_generator() @@ -157,6 +144,7 @@ def some_data() -> Generator[List[DictStrAny], Any, None]: info = pipeline.run( some_data(), ) + assert_load_info(info) assert_table(pipeline, "some_data", items=data) @@ -282,7 +270,7 @@ def test_pipeline_merge() -> None: def movies_data() -> Any: yield data - @dlt.resource(primary_key="doc_id", merge_key=["merge_id", "title"]) + @dlt.resource(primary_key=["doc_id", "merge_id", "title"], merge_key="doc_id") def movies_data_explicit_merge_keys() -> Any: yield data @@ -441,121 +429,22 @@ def test_merge_github_nested() -> None: def test_empty_dataset_allowed() -> None: # dataset_name is optional so dataset name won't be autogenerated when not explicitly passed. pipe = dlt.pipeline(destination="lancedb", dev_mode=True) - client: LanceDBClient = pipe.destination_client() # type: ignore[assignment] assert pipe.dataset_name is None info = pipe.run(lancedb_adapter(["context", "created", "not a stop word"], embed=["value"])) # Dataset in load info is empty. 
assert info.dataset_name is None - client = pipe.destination_client() # type: ignore[assignment] - assert client.dataset_name is None - assert client.sentinel_table == "dltSentinelTable" + client = pipe.destination_client() + assert client.dataset_name is None # type: ignore + assert client.sentinel_table == "dltSentinelTable" # type: ignore assert_table(pipe, "content", expected_items_count=3) def test_merge_no_orphans() -> None: @dlt.resource( - write_disposition="merge", - primary_key=["doc_id"], - table_name="document", - ) - def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: - for doc in docs: - doc_id = doc["doc_id"] - chunks = chunk_document(doc["text"]) - embeddings = [ - { - "chunk_hash": digest128(chunk), - "chunk_text": chunk, - "embedding": mock_embed(), - } - for chunk in chunks - ] - yield {"doc_id": doc_id, "doc_text": doc["text"], "embeddings": embeddings} - - @dlt.source(max_table_nesting=1) - def documents_source( - docs: List[DictStrAny], - ) -> Union[Generator[Dict[str, Any], None, None], DltResource]: - return documents(docs) - - pipeline = dlt.pipeline( - pipeline_name="chunked_docs", - destination="lancedb", - dataset_name="chunked_documents", - dev_mode=True, - ) - - initial_docs = [ - { - "text": ( - "This is the first document. It contains some text that will be chunked and" - " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)" - ), - "doc_id": 1, - }, - { - "text": "Here's another document. It's a bit different from the first one.", - "doc_id": 2, - }, - ] - - info = pipeline.run(documents_source(initial_docs)) - assert_load_info(info) - - updated_docs = [ - { - "text": "This is the first document, but it has been updated with new content.", - "doc_id": 1, - }, - { - "text": "This is a completely new document that wasn't in the initial set.", - "doc_id": 3, - }, - ] - - info = pipeline.run(documents_source(updated_docs)) - assert_load_info(info) - - with pipeline.destination_client() as client: - # Orphaned chunks/documents must have been discarded. - # Shouldn't contain any text from `initial_docs' where doc_id=1. - expected_text = { - "Here's ano", - "ther docum", - "ent. 
It's ", - "a bit diff", - "erent from", - " the first", - " one.", - "This is th", - "e first do", - "cument, bu", - "t it has b", - "een update", - "d with new", - " content.", - "This is a ", - "completely", - " new docum", - "ent that w", - "asn't in t", - "he initial", - " set.", - } - - embeddings_table_name = client.make_qualified_table_name("document__embeddings") # type: ignore[attr-defined] - - tbl: Table = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] - df = tbl.to_pandas() - assert set(df["chunk_text"]) == expected_text - - -def test_merge_no_orphans_with_doc_id() -> None: - @dlt.resource( # type: ignore write_disposition="merge", table_name="document", - columns={"doc_id": {DOCUMENT_ID_HINT: True}}, + merge_key=["doc_id"], ) def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: for doc in docs: @@ -647,55 +536,3 @@ def documents_source( tbl: Table = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] df = tbl.to_pandas() assert set(df["chunk_text"]) == expected_text - - -def test_primary_key_not_compatible_with_doc_id_hint_on_merge_disposition() -> None: - @dlt.resource( # type: ignore - write_disposition="merge", - table_name="document", - primary_key="doc_id", - columns={"doc_id": {DOCUMENT_ID_HINT: True}}, - ) - def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: - for doc in docs: - doc_id = doc["doc_id"] - chunks = chunk_document(doc["text"]) - embeddings = [ - { - "chunk_hash": digest128(chunk), - "chunk_text": chunk, - "embedding": mock_embed(), - } - for chunk in chunks - ] - yield {"doc_id": doc_id, "doc_text": doc["text"], "embeddings": embeddings} - - @dlt.source(max_table_nesting=1) - def documents_source( - docs: List[DictStrAny], - ) -> Union[Generator[Dict[str, Any], None, None], DltResource]: - return documents(docs) - - pipeline = dlt.pipeline( - pipeline_name="test_mandatory_doc_id_hint_on_merge_disposition", - destination="lancedb", - dataset_name="test_mandatory_doc_id_hint_on_merge_disposition", - dev_mode=True, - ) - - initial_docs = [ - { - "text": ( - "This is the first document. It contains some text that will be chunked and" - " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)" - ), - "doc_id": 1, - }, - { - "text": "Here's another document. 
It's a bit different from the first one.", - "doc_id": 2, - }, - ] - - with pytest.raises(PipelineStepFailed): - pipeline.run(documents(initial_docs)) diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 6346171574..65489f6bda 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -6,13 +6,11 @@ from lancedb.table import Table # type: ignore from pandas import DataFrame from pandas.testing import assert_frame_equal -from pyarrow import Table import dlt from dlt.common.typing import DictStrAny from dlt.common.utils import uniq_id from dlt.destinations.impl.lancedb.lancedb_adapter import ( - DOCUMENT_ID_HINT, lancedb_adapter, ) from tests.load.lancedb.utils import chunk_document @@ -42,11 +40,7 @@ def test_lancedb_remove_orphaned_records() -> None: dev_mode=True, ) - @dlt.resource( # type: ignore[call-overload] - table_name="parent", - write_disposition="merge", - columns={"id": {DOCUMENT_ID_HINT: True}}, - ) + @dlt.resource(table_name="parent", write_disposition="merge", merge_key=["id"]) def identity_resource( data: List[DictStrAny], ) -> Generator[List[DictStrAny], None, None]: @@ -143,11 +137,11 @@ def test_lancedb_remove_orphaned_records_root_table() -> None: dev_mode=True, ) - @dlt.resource( # type: ignore[call-overload] + @dlt.resource( table_name="root", write_disposition="merge", - merge_key=["chunk_hash"], - columns={"doc_id": {DOCUMENT_ID_HINT: True}}, + primary_key=["doc_id", "chunk_hash"], + merge_key=["doc_id"], ) def identity_resource( data: List[DictStrAny], @@ -201,7 +195,8 @@ def test_lancedb_root_table_remove_orphaned_records_with_real_embeddings() -> No @dlt.resource( write_disposition="merge", table_name="document", - merge_key=["chunk"], + primary_key=["doc_id", "chunk"], + merge_key=["doc_id"], ) def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: for doc in docs: From 15a0cf666fe7fb3ece6f61d57bbdfb375952e2c3 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 27 Aug 2024 22:26:08 +0200 Subject: [PATCH 072/113] Infer merge key if not supplied from provided primary key Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 12 +++---- dlt/destinations/impl/lancedb/utils.py | 32 ++++++++++++++++++- tests/load/lancedb/test_pipeline.py | 22 ++----------- .../lancedb/test_remove_orphaned_records.py | 4 +-- 4 files changed, 39 insertions(+), 31 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 57639a139a..58974afc3b 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -75,6 +75,7 @@ from dlt.destinations.impl.lancedb.utils import ( set_non_standard_providers_environment_variables, get_default_arrow_value, + get_lancedb_merge_key, ) from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.type_mapping import TypeMapper @@ -162,7 +163,7 @@ def write_records( db_client: DBConnection, table_name: str, write_disposition: Optional[TWriteDisposition] = "append", - merge_key: Optional[Union[str, List[str]]] = None, + merge_key: Optional[str] = None, remove_orphans: Optional[bool] = False, ) -> None: """Inserts records into a LanceDB table with automatic embedding computation. 
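As a usage-level sketch of what this inference means for a pipeline (resource and field names are illustrative): a resource that declares only a primary key now merges on that key, with the proxy-key warning added in utils.py below, while declaring merge_key explicitly is the preferred spelling:

    import dlt

    @dlt.resource(primary_key="doc_id", write_disposition="merge")
    def docs_by_primary_key():
        # no merge_key declared: the loader falls back to the primary key and logs a warning
        yield {"doc_id": 1, "text": "first document"}

    @dlt.resource(merge_key="doc_id", write_disposition="merge")
    def docs_by_merge_key():
        # explicit merge_key: used directly, no fallback needed
        yield {"doc_id": 1, "text": "first document"}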
@@ -740,14 +741,9 @@ def run(self) -> None: with FileStorage.open_zipsafe_ro(self._file_path, mode="rb") as f: arrow_table: pa.Table = pq.read_table(f) - # TODO: To function merge_key = ( - get_columns_names_with_prop(self._load_table, "merge_key") - or get_columns_names_with_prop(self._load_table, "primary_key") - or get_columns_names_with_prop(self._load_table, "unique") + get_lancedb_merge_key(self._load_table) if write_disposition == "merge" else None ) - if isinstance(merge_key, list) and len(merge_key) >= 1: - merge_key = merge_key[0] write_records( arrow_table, @@ -820,7 +816,7 @@ def run(self) -> None: write_records( payload_arrow_table, db_client=db_client, - merge_key=target_table_id_field_name, # type: ignore + merge_key=target_table_id_field_name, table_name=fq_table_name, write_disposition="merge", remove_orphans=True, diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index e1847120bd..04a1ecaa23 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -6,6 +6,8 @@ import pyarrow as pa import pyarrow.compute as pc +from dlt.common import logger +from dlt.common.destination.exceptions import DestinationTerminalException from dlt.common.schema import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider @@ -63,7 +65,7 @@ def get_unique_identifiers_from_table_schema(table_schema: TTableSchema) -> List """ primary_keys = get_columns_names_with_prop(table_schema, "primary_key") merge_keys = [] - if table_schema.get("write_disposition")=="merge": + if table_schema.get("write_disposition") == "merge": merge_keys = get_columns_names_with_prop(table_schema, "merge_key") if join_keys := list(set(primary_keys + merge_keys)): return join_keys @@ -93,3 +95,31 @@ def get_default_arrow_value(field_type: TArrowDataType) -> object: return datetime.now() else: raise ValueError(f"Unsupported data type: {field_type}") + + +def get_lancedb_merge_key(load_table: TTableSchema) -> str: + if merge_key_ := get_columns_names_with_prop(load_table, "merge_key"): + if len(merge_key_) > 1: + DestinationTerminalException( + "LanceDB destination does not support compound merge keys." + ) + if primary_key := get_columns_names_with_prop(load_table, "primary_key"): + if merge_key_: + logger.warning( + "LanceDB destination currently does not yet support primary key constraints:" + " https://github.com/lancedb/lancedb/issues/1120. Only `merge_key` is supported." + " The supplied primary key will be ignored!" + ) + elif len(primary_key) == 1: + logger.warning( + "LanceDB destination currently does not yet support primary key constraints:" + " https://github.com/lancedb/lancedb/issues/1120. Only `merge_key` is supported." + " The primary key will be used as a proxy merge key! Please use `merge_key` instead." 
+ ) + return primary_key[0] + if len(merge_key_) == 1: + return merge_key_[0] + elif len(unique_key := get_columns_names_with_prop(load_table, "unique")) == 1: + return unique_key[0] + else: + return "" diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 7405b8426e..fe83d3d7ff 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -15,7 +15,6 @@ ) from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient from dlt.extract import DltResource -from dlt.pipeline.exceptions import PipelineStepFailed from tests.load.lancedb.utils import assert_table, chunk_document, mock_embed from tests.load.utils import sequence_generator, drop_active_pipeline_data from tests.pipeline.utils import assert_load_info @@ -127,7 +126,7 @@ def test_explicit_append() -> None: {"doc_id": 3, "content": "3"}, ] - @dlt.resource(primary_key="doc_id") + @dlt.resource(merge_key="doc_id") def some_data() -> Generator[List[DictStrAny], Any, None]: yield data @@ -266,11 +265,11 @@ def test_pipeline_merge() -> None: }, ] - @dlt.resource(primary_key="doc_id") + @dlt.resource() def movies_data() -> Any: yield data - @dlt.resource(primary_key=["doc_id", "merge_id", "title"], merge_key="doc_id") + @dlt.resource(merge_key="doc_id") def movies_data_explicit_merge_keys() -> Any: yield data @@ -307,21 +306,6 @@ def movies_data_explicit_merge_keys() -> Any: assert_load_info(info) assert_table(pipeline, "movies_data", items=data) - info = pipeline.run( - movies_data(), - write_disposition="merge", - ) - assert_load_info(info) - assert_table(pipeline, "movies_data", items=data) - - # Test with explicit merge keys. - info = pipeline.run( - movies_data_explicit_merge_keys(), - write_disposition="merge", - ) - assert_load_info(info) - assert_table(pipeline, "movies_data_explicit_merge_keys", items=data) - def test_pipeline_with_schema_evolution() -> None: data = [ diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 65489f6bda..40b084d430 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -140,7 +140,6 @@ def test_lancedb_remove_orphaned_records_root_table() -> None: @dlt.resource( table_name="root", write_disposition="merge", - primary_key=["doc_id", "chunk_hash"], merge_key=["doc_id"], ) def identity_resource( @@ -195,8 +194,7 @@ def test_lancedb_root_table_remove_orphaned_records_with_real_embeddings() -> No @dlt.resource( write_disposition="merge", table_name="document", - primary_key=["doc_id", "chunk"], - merge_key=["doc_id"], + merge_key="doc_id", ) def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: for doc in docs: From e9462e373b877309e4dc9f920fc7e30c0849513d Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 27 Aug 2024 22:29:06 +0200 Subject: [PATCH 073/113] Remove unused utility functions Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/utils.py | 58 ++------------------------ 1 file changed, 3 insertions(+), 55 deletions(-) diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 04a1ecaa23..7898bd4e01 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -1,10 +1,8 @@ import os -import uuid from datetime import date, datetime -from typing import Sequence, Union, Dict, List +from typing import Union, Dict import pyarrow as pa -import pyarrow.compute as pc from 
dlt.common import logger from dlt.common.destination.exceptions import DestinationTerminalException @@ -22,57 +20,6 @@ } -# TODO: Update `generate_arrow_uuid_column` when pyarrow 17.0.0 becomes available with vectorized operations (batched + memory-mapped) -def generate_arrow_uuid_column( - table: pa.Table, unique_identifiers: Sequence[str], id_field_name: str, table_name: str -) -> pa.Table: - """Generates deterministic UUID - used for deduplication, returning a new arrow - table with added UUID column. - - Args: - table (pa.Table): PyArrow table to generate UUIDs for. - unique_identifiers (Sequence[str]): A list of unique identifier column names. - id_field_name (str): Name of the new UUID column. - table_name (str): Name of the table. - - Returns: - pa.Table: New PyArrow table with the new UUID column. - """ - - unique_identifiers_columns = [] - for col in unique_identifiers: - column = pc.fill_null(pc.cast(table[col], pa.string()), "") - unique_identifiers_columns.append(column.to_pylist()) - - uuids = pa.array( - [ - str(uuid.uuid5(uuid.NAMESPACE_OID, x + table_name)) - for x in ["".join(x) for x in zip(*unique_identifiers_columns)] - ] - ) - - return table.append_column(id_field_name, uuids) - - -def get_unique_identifiers_from_table_schema(table_schema: TTableSchema) -> List[str]: - """Returns a list of merge keys for a table used for either merging or deduplication. - - Args: - table_schema (TTableSchema): a dlt table schema. - - Returns: - Sequence[str]: A list of unique column identifiers. - """ - primary_keys = get_columns_names_with_prop(table_schema, "primary_key") - merge_keys = [] - if table_schema.get("write_disposition") == "merge": - merge_keys = get_columns_names_with_prop(table_schema, "merge_key") - if join_keys := list(set(primary_keys + merge_keys)): - return join_keys - else: - return get_columns_names_with_prop(table_schema, "unique") - - def set_non_standard_providers_environment_variables( embedding_model_provider: TEmbeddingProvider, api_key: Union[str, None] ) -> None: @@ -114,7 +61,8 @@ def get_lancedb_merge_key(load_table: TTableSchema) -> str: logger.warning( "LanceDB destination currently does not yet support primary key constraints:" " https://github.com/lancedb/lancedb/issues/1120. Only `merge_key` is supported." - " The primary key will be used as a proxy merge key! Please use `merge_key` instead." + " The primary key will be used as a proxy merge key! Please use `merge_key`" + " instead." ) return primary_key[0] if len(merge_key_) == 1: From 8af98d7a286517a5bed88a07c510f2e0458c7f75 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 27 Aug 2024 23:06:05 +0200 Subject: [PATCH 074/113] Remove LanceDB doc ID hints and use schema normalizer Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 58974afc3b..3805afe238 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -442,7 +442,7 @@ def extend_lancedb_table_schema(self, table_name: str, field_schemas: List[pa.Fi try: # Use DataFusion SQL syntax to alter fields without loading data into client memory. - # Currently, the most efficient way to modify column values is in LanceDB. + # Now, the most efficient way to modify column values is in LanceDB. 
new_fields = { field.name: f"CAST(NULL AS {arrow_datatype_to_fusion_datatype(field.type)})" for field in field_schemas @@ -489,7 +489,6 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: else: embedding_fields = None vector_field_name = None - id_field_name = None embedding_model_func = None embedding_model_dimensions = None @@ -551,7 +550,9 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: # normalize property names p_load_id = self.schema.naming.normalize_identifier("load_id") - p_dlt_load_id = self.schema.naming.normalize_identifier("_dlt_load_id") + p_dlt_load_id = self.schema.naming.normalize_identifier( + self.schema.data_item_normalizer.C_DLT_LOAD_ID # type: ignore[attr-defined] + ) p_pipeline_name = self.schema.naming.normalize_identifier("pipeline_name") p_status = self.schema.naming.normalize_identifier("status") p_version = self.schema.naming.normalize_identifier("version") @@ -783,10 +784,11 @@ def run(self) -> None: fq_table_name = self._job_client.make_qualified_table_name(job.table_name) if target_is_root_table: - target_table_id_field_name = "_dlt_id" + target_table_id_field_name = self._schema.data_item_normalizer.C_DLT_ID # type: ignore[attr-defined] file_path = job.file_path else: - target_table_id_field_name = "_dlt_parent_id" + # This should look for root NOT parent. more efficient! + target_table_id_field_name = self._schema.data_item_normalizer.C_DLT_ROOT_ID # type: ignore[attr-defined] file_path = self.get_parent_path(table_lineage, job.table_schema.get("parent")) with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: From 4195bb4feed75e0b0075a021ad3c39871e970e1b Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Wed, 28 Aug 2024 00:13:11 +0200 Subject: [PATCH 075/113] LanceDB writes strange code Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_adapter.py | 22 ++++++- .../impl/lancedb/lancedb_client.py | 2 +- dlt/destinations/impl/lancedb/utils.py | 56 +++++++++--------- ...move_orphaned_records.py => test_merge.py} | 57 ++++++++++++++++++- tests/load/lancedb/test_pipeline.py | 16 +++++- 5 files changed, 119 insertions(+), 34 deletions(-) rename tests/load/lancedb/{test_remove_orphaned_records.py => test_merge.py} (81%) diff --git a/dlt/destinations/impl/lancedb/lancedb_adapter.py b/dlt/destinations/impl/lancedb/lancedb_adapter.py index df356430e4..8daef77f68 100644 --- a/dlt/destinations/impl/lancedb/lancedb_adapter.py +++ b/dlt/destinations/impl/lancedb/lancedb_adapter.py @@ -11,6 +11,7 @@ def lancedb_adapter( data: Any, embed: TColumnNames = None, + merge_key: TColumnNames = None, ) -> DltResource: """Prepares data for the LanceDB destination by specifying which columns should be embedded. @@ -20,6 +21,8 @@ def lancedb_adapter( object. embed (TColumnNames, optional): Specify columns to generate embeddings for. It can be a single column name as a string, or a list of column names. + merge_key (TColumnNames, optional): Specify columns to merge on. + It can be a single column name as a string, or a list of column names. Returns: DltResource: A resource with applied LanceDB-specific hints. @@ -50,8 +53,25 @@ def lancedb_adapter( VECTORIZE_HINT: True, # type: ignore[misc] } + if merge_key: + if isinstance(merge_key, str): + merge_key = [merge_key] + if not isinstance(merge_key, list): + raise ValueError( + "'merge_key' must be a list of column names or a single column name as a string." 
+ ) + + for column_name in merge_key: + column_hints[column_name] = { + "name": column_name, + "merge_key": True, + } + if not column_hints: - raise ValueError("You must must provide the 'embed' argument if using the adapter.") + raise ValueError( + "You must must provide at least either the 'embed' or 'merge_key' argument if using the" + " adapter." + ) else: resource.apply_hints(columns=column_hints) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 3805afe238..92de8a984b 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -163,7 +163,7 @@ def write_records( db_client: DBConnection, table_name: str, write_disposition: Optional[TWriteDisposition] = "append", - merge_key: Optional[str] = None, + merge_key: Optional[Union[str, List[str]]] = None, remove_orphans: Optional[bool] = False, ) -> None: """Inserts records into a LanceDB table with automatic embedding computation. diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 7898bd4e01..885b24eebd 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -1,10 +1,9 @@ import os from datetime import date, datetime -from typing import Union, Dict +from typing import Union, Dict, Optional, TypeVar, Generic, Iterable, Iterator import pyarrow as pa -from dlt.common import logger from dlt.common.destination.exceptions import DestinationTerminalException from dlt.common.schema import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop @@ -44,30 +43,31 @@ def get_default_arrow_value(field_type: TArrowDataType) -> object: raise ValueError(f"Unsupported data type: {field_type}") -def get_lancedb_merge_key(load_table: TTableSchema) -> str: - if merge_key_ := get_columns_names_with_prop(load_table, "merge_key"): - if len(merge_key_) > 1: - DestinationTerminalException( - "LanceDB destination does not support compound merge keys." - ) - if primary_key := get_columns_names_with_prop(load_table, "primary_key"): - if merge_key_: - logger.warning( - "LanceDB destination currently does not yet support primary key constraints:" - " https://github.com/lancedb/lancedb/issues/1120. Only `merge_key` is supported." - " The supplied primary key will be ignored!" - ) - elif len(primary_key) == 1: - logger.warning( - "LanceDB destination currently does not yet support primary key constraints:" - " https://github.com/lancedb/lancedb/issues/1120. Only `merge_key` is supported." - " The primary key will be used as a proxy merge key! Please use `merge_key`" - " instead." - ) - return primary_key[0] - if len(merge_key_) == 1: - return merge_key_[0] - elif len(unique_key := get_columns_names_with_prop(load_table, "unique")) == 1: - return unique_key[0] +ItemType = TypeVar('ItemType') + + +# LanceDB `merge_insert` expects an 'iter()' method instead of using standard iteration. 
+# https://github.com/lancedb/lancedb/blob/ae85008714792a6b724c75793b63273c51caba88/python/python/lancedb/table.py#L2264 +class IterableWrapper(Generic[ItemType]): + def __init__(self, iterable: Iterable[ItemType]) -> None: + self.iterable = iterable + + def __iter__(self) -> Iterator[ItemType]: + return iter(self.iterable) + + def iter(self) -> Iterator[ItemType]: + return iter(self.iterable) + + +def get_lancedb_merge_key(load_table: TTableSchema) -> Optional[Union[str, IterableWrapper[str]]]: + if get_columns_names_with_prop(load_table, "primary_key"): + raise DestinationTerminalException( + "LanceDB destination currently does not yet support primary key constraints: " + "https://github.com/lancedb/lancedb/issues/1120. Only `merge_key` is supported!" + ) + if merge_key := get_columns_names_with_prop(load_table, "merge_key"): + return merge_key[0] if len(merge_key)==1 else IterableWrapper(merge_key) + elif unique_key := get_columns_names_with_prop(load_table, "unique"): + return unique_key[0] if len(unique_key)==1 else IterableWrapper(unique_key) else: - return "" + return None diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_merge.py similarity index 81% rename from tests/load/lancedb/test_remove_orphaned_records.py rename to tests/load/lancedb/test_merge.py index 40b084d430..7942bd8404 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_merge.py @@ -140,7 +140,7 @@ def test_lancedb_remove_orphaned_records_root_table() -> None: @dlt.resource( table_name="root", write_disposition="merge", - merge_key=["doc_id"], + merge_key=["doc_id", "chunk_hash"], ) def identity_resource( data: List[DictStrAny], @@ -262,3 +262,58 @@ def documents_source( for _, vector in enumerate(df["vector__"]): assert isinstance(vector, np.ndarray) assert vector.size > 0 + + +def test_lancedb_compound_merge_key_root_table() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_compound_merge_key", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_root_table_{uniq_id()}", + dev_mode=True, + ) + + @dlt.resource( + table_name="root", + write_disposition="merge", + merge_key=["doc_id", "chunk_hash"], + ) + def identity_resource( + data: List[DictStrAny], + ) -> Generator[List[DictStrAny], None, None]: + yield data + + run_1 = [ + {"doc_id": 1, "chunk_hash": "a", "foo": "bar"}, + {"doc_id": 1, "chunk_hash": "b", "foo": "coo"}, + ] + info = pipeline.run(identity_resource(run_1)) + assert_load_info(info) + + run_2 = [ + {"doc_id": 1, "chunk_hash": "a", "foo": "aat"}, + {"doc_id": 1, "chunk_hash": "c", "foo": "loot"}, + ] + info = pipeline.run(identity_resource(run_2)) + assert_load_info(info) + + with pipeline.destination_client() as client: + expected_root_table_df = ( + pd.DataFrame( + data=[ + {"doc_id": 1, "chunk_hash": "a", "foo": "aat"}, + {"doc_id": 1, "chunk_hash": "b", "foo": "coo"}, + {"doc_id": 1, "chunk_hash": "c", "foo": "loot"}, + ] + ) + .sort_values(by=["doc_id", "chunk_hash", "foo"]) + .reset_index(drop=True) + ) + + root_table_name = client.make_qualified_table_name("root") # type: ignore[attr-defined] + tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] + + actual_root_df: DataFrame = ( + tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash", "foo"]).reset_index(drop=True) + )[["doc_id", "chunk_hash", "foo"]] + + assert_frame_equal(actual_root_df, expected_root_table_df) diff --git a/tests/load/lancedb/test_pipeline.py 
b/tests/load/lancedb/test_pipeline.py index fe83d3d7ff..77d70e075e 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -50,6 +50,18 @@ def some_data() -> Generator[DictStrStr, Any, None]: "x-lancedb-embed": True, } + lancedb_adapter( + some_data, + merge_key=["content"], + ) + + assert some_data.columns["content"] == { # type: ignore + "name": "content", + "data_type": "text", + "x-lancedb-embed": True, + "merge_key": True, + } + def test_basic_state_and_schema() -> None: generator_instance1 = sequence_generator() @@ -119,7 +131,6 @@ def some_data() -> Generator[DictStrStr, Any, None]: def test_explicit_append() -> None: - """Append should work even when the primary key is specified.""" data = [ {"doc_id": 1, "content": "1"}, {"doc_id": 2, "content": "2"}, @@ -375,10 +386,9 @@ def test_merge_github_nested() -> None: data = json.load(f) info = pipe.run( - lancedb_adapter(data[:17], embed=["title", "body"]), + lancedb_adapter(data[:17], embed=["title", "body"], merge_key="id"), table_name="issues", write_disposition="merge", - primary_key="id", ) assert_load_info(info) # assert if schema contains tables with right names From 2573d3a758e386d47bee14e8771c409de0feea6b Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Wed, 28 Aug 2024 00:25:44 +0200 Subject: [PATCH 076/113] Minor Formatting Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 3 ++- dlt/destinations/impl/lancedb/schema.py | 6 +++--- dlt/destinations/impl/lancedb/utils.py | 6 +++--- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 92de8a984b..8ef66e0135 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -76,6 +76,7 @@ set_non_standard_providers_environment_variables, get_default_arrow_value, get_lancedb_merge_key, + IterableWrapper, ) from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.type_mapping import TypeMapper @@ -163,7 +164,7 @@ def write_records( db_client: DBConnection, table_name: str, write_disposition: Optional[TWriteDisposition] = "append", - merge_key: Optional[Union[str, List[str]]] = None, + merge_key: Optional[Union[str, IterableWrapper[str]]] = None, remove_orphans: Optional[bool] = False, ) -> None: """Inserts records into a LanceDB table with automatic embedding computation. 
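A small sketch of how a compound key reaches merge_insert after this typing change, assuming the IterableWrapper added to utils.py in the previous commit plus the same `tbl` and `records` handles as above (column names are illustrative); a single-column key is still passed as a plain string:

    from dlt.destinations.impl.lancedb.utils import IterableWrapper

    # LanceDB's merge_insert resolves its `on` columns through an .iter() method,
    # so a compound key is handed over wrapped rather than as a bare list.
    on = IterableWrapper(["doc_id", "chunk_hash"])
    tbl.merge_insert(on).when_matched_update_all().when_not_matched_insert_all().execute(records)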
diff --git a/dlt/destinations/impl/lancedb/schema.py b/dlt/destinations/impl/lancedb/schema.py index 66542fe2b1..2fa3251ede 100644 --- a/dlt/destinations/impl/lancedb/schema.py +++ b/dlt/destinations/impl/lancedb/schema.py @@ -3,7 +3,8 @@ from typing import ( List, cast, - Optional, Tuple, + Optional, + Tuple, ) import pyarrow as pa @@ -21,7 +22,7 @@ TArrowField: TypeAlias = pa.Field NULL_SCHEMA: TArrowSchema = pa.schema([]) """Empty pyarrow Schema with no fields.""" -TableJob = namedtuple('TableJob', ['table_schema', 'table_name', 'file_path']) +TableJob = namedtuple("TableJob", ["table_schema", "table_name", "file_path"]) TTableLineage: TypeAlias = List[TableJob] @@ -99,4 +100,3 @@ def arrow_datatype_to_fusion_datatype(arrow_type: TArrowSchema) -> str: return "TIMESTAMP" return type_map.get(arrow_type, "UNKNOWN") - diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 885b24eebd..f335c9009b 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -43,7 +43,7 @@ def get_default_arrow_value(field_type: TArrowDataType) -> object: raise ValueError(f"Unsupported data type: {field_type}") -ItemType = TypeVar('ItemType') +ItemType = TypeVar("ItemType") # LanceDB `merge_insert` expects an 'iter()' method instead of using standard iteration. @@ -66,8 +66,8 @@ def get_lancedb_merge_key(load_table: TTableSchema) -> Optional[Union[str, Itera "https://github.com/lancedb/lancedb/issues/1120. Only `merge_key` is supported!" ) if merge_key := get_columns_names_with_prop(load_table, "merge_key"): - return merge_key[0] if len(merge_key)==1 else IterableWrapper(merge_key) + return merge_key[0] if len(merge_key) == 1 else IterableWrapper(merge_key) elif unique_key := get_columns_names_with_prop(load_table, "unique"): - return unique_key[0] if len(unique_key)==1 else IterableWrapper(unique_key) + return unique_key[0] if len(unique_key) == 1 else IterableWrapper(unique_key) else: return None From 86c198c20064a8f18c488b9d03de4b5165371599 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Wed, 28 Aug 2024 16:48:48 +0200 Subject: [PATCH 077/113] Support compound primary and merge keys Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/factory.py | 2 + .../impl/lancedb/lancedb_client.py | 11 +-- dlt/destinations/impl/lancedb/utils.py | 16 ++-- tests/load/lancedb/test_merge.py | 75 +++++++++++++++++-- tests/load/lancedb/test_pipeline.py | 3 +- 5 files changed, 90 insertions(+), 17 deletions(-) diff --git a/dlt/destinations/impl/lancedb/factory.py b/dlt/destinations/impl/lancedb/factory.py index d9b92e02b9..cd792b6cee 100644 --- a/dlt/destinations/impl/lancedb/factory.py +++ b/dlt/destinations/impl/lancedb/factory.py @@ -32,6 +32,8 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.recommended_file_size = 128_000_000 + caps.supported_merge_strategies = ["delete-insert", "upsert"] + return caps @property diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 8ef66e0135..315028f871 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -51,7 +51,7 @@ TWriteDisposition, TColumnSchema, ) -from dlt.common.schema.utils import get_columns_names_with_prop +from dlt.common.schema.utils import get_columns_names_with_prop, DEFAULT_MERGE_STRATEGY from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName from dlt.destinations.impl.lancedb.configuration import ( 
LanceDBClientConfiguration, @@ -704,7 +704,10 @@ def create_table_chain_completed_followup_jobs( jobs = super().create_table_chain_completed_followup_jobs( table_chain, completed_table_chain_jobs ) - if table_chain[0].get("write_disposition") == "merge": + # Orphan removal is only supported for upsert strategy because we need a deterministic key hash. + first_table_in_chain = table_chain[0] + merge_strategy = first_table_in_chain.get("x-merge-strategy", DEFAULT_MERGE_STRATEGY) + if first_table_in_chain.get("write_disposition") == "merge" and merge_strategy == "upsert": all_job_paths_ordered = [ job.file_path for table in table_chain @@ -743,9 +746,7 @@ def run(self) -> None: with FileStorage.open_zipsafe_ro(self._file_path, mode="rb") as f: arrow_table: pa.Table = pq.read_table(f) - merge_key = ( - get_lancedb_merge_key(self._load_table) if write_disposition == "merge" else None - ) + merge_key = self._schema.data_item_normalizer.C_DLT_ID # type: ignore[attr-defined] write_records( arrow_table, diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index f335c9009b..735500f9f1 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -4,7 +4,7 @@ import pyarrow as pa -from dlt.common.destination.exceptions import DestinationTerminalException +from dlt.common import logger from dlt.common.schema import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider @@ -60,14 +60,18 @@ def iter(self) -> Iterator[ItemType]: def get_lancedb_merge_key(load_table: TTableSchema) -> Optional[Union[str, IterableWrapper[str]]]: - if get_columns_names_with_prop(load_table, "primary_key"): - raise DestinationTerminalException( - "LanceDB destination currently does not yet support primary key constraints: " - "https://github.com/lancedb/lancedb/issues/1120. Only `merge_key` is supported!" - ) if merge_key := get_columns_names_with_prop(load_table, "merge_key"): return merge_key[0] if len(merge_key) == 1 else IterableWrapper(merge_key) + elif primary_key := get_columns_names_with_prop(load_table, "primary_key"): + # No merge key defined, warn and merge on the primary key. + logger.warning( + "Merge strategy selected without defined merge key - using primary key as merge key." + ) + return primary_key[0] if len(primary_key) == 1 else IterableWrapper(merge_key) elif unique_key := get_columns_names_with_prop(load_table, "unique"): + logger.warning( + "Merge strategy selected without defined merge key - using unique key as merge key." 
+ ) return unique_key[0] if len(unique_key) == 1 else IterableWrapper(unique_key) else: return None diff --git a/tests/load/lancedb/test_merge.py b/tests/load/lancedb/test_merge.py index 7942bd8404..120e6f4233 100644 --- a/tests/load/lancedb/test_merge.py +++ b/tests/load/lancedb/test_merge.py @@ -40,7 +40,12 @@ def test_lancedb_remove_orphaned_records() -> None: dev_mode=True, ) - @dlt.resource(table_name="parent", write_disposition="merge", merge_key=["id"]) + @dlt.resource( + table_name="parent", + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key="id", + merge_key="id", + ) def identity_resource( data: List[DictStrAny], ) -> Generator[List[DictStrAny], None, None]: @@ -139,8 +144,9 @@ def test_lancedb_remove_orphaned_records_root_table() -> None: @dlt.resource( table_name="root", - write_disposition="merge", - merge_key=["doc_id", "chunk_hash"], + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key=["doc_id", "chunk_hash"], + merge_key=["doc_id"], ) def identity_resource( data: List[DictStrAny], @@ -192,8 +198,9 @@ def identity_resource( def test_lancedb_root_table_remove_orphaned_records_with_real_embeddings() -> None: @dlt.resource( - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": "upsert"}, table_name="document", + primary_key="doc_id", merge_key="doc_id", ) def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: @@ -257,7 +264,7 @@ def documents_source( df = tbl.to_pandas() # Check (non-empty) embeddings as present, and that orphaned embeddings have been discarded. - assert len(df) == 21 + assert len(df)==21 assert "vector__" in df.columns for _, vector in enumerate(df["vector__"]): assert isinstance(vector, np.ndarray) @@ -275,6 +282,63 @@ def test_lancedb_compound_merge_key_root_table() -> None: @dlt.resource( table_name="root", write_disposition="merge", + primary_key=["doc_id", "chunk_hash"], + merge_key=["doc_id", "chunk_hash"], + ) + def identity_resource( + data: List[DictStrAny], + ) -> Generator[List[DictStrAny], None, None]: + yield data + + run_1 = [ + {"doc_id": 1, "chunk_hash": "a", "foo": "bar"}, + {"doc_id": 1, "chunk_hash": "b", "foo": "coo"}, + ] + info = pipeline.run(identity_resource(run_1)) + assert_load_info(info) + + run_2 = [ + {"doc_id": 1, "chunk_hash": "a", "foo": "aat"}, + {"doc_id": 1, "chunk_hash": "c", "foo": "loot"}, + ] + info = pipeline.run(identity_resource(run_2)) + assert_load_info(info) + + with pipeline.destination_client() as client: + expected_root_table_df = ( + pd.DataFrame( + data=[ + {"doc_id": 1, "chunk_hash": "a", "foo": "aat"}, + {"doc_id": 1, "chunk_hash": "b", "foo": "coo"}, + {"doc_id": 1, "chunk_hash": "c", "foo": "loot"}, + ] + ) + .sort_values(by=["doc_id", "chunk_hash", "foo"]) + .reset_index(drop=True) + ) + + root_table_name = client.make_qualified_table_name("root") # type: ignore[attr-defined] + tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] + + actual_root_df: DataFrame = ( + tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash", "foo"]).reset_index(drop=True) + )[["doc_id", "chunk_hash", "foo"]] + + assert_frame_equal(actual_root_df, expected_root_table_df) + + +def test_lancedb_compound_merge_key_root_table() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_compound_merge_key", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_root_table_{uniq_id()}", + dev_mode=True, + ) + + @dlt.resource( + table_name="root", + 
write_disposition="merge", + primary_key=["doc_id", "chunk_hash"], merge_key=["doc_id", "chunk_hash"], ) def identity_resource( @@ -301,6 +365,7 @@ def identity_resource( pd.DataFrame( data=[ {"doc_id": 1, "chunk_hash": "a", "foo": "aat"}, + {"doc_id": 1, "chunk_hash": "a", "foo": "bar"}, {"doc_id": 1, "chunk_hash": "b", "foo": "coo"}, {"doc_id": 1, "chunk_hash": "c", "foo": "loot"}, ] diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 77d70e075e..4030f92f78 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -436,8 +436,9 @@ def test_empty_dataset_allowed() -> None: def test_merge_no_orphans() -> None: @dlt.resource( - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": "upsert"}, table_name="document", + primary_key=["doc_id", "chunk_hash"], merge_key=["doc_id"], ) def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: From aa03930fbf426a3ba48aa1fbf50e04e59be3dbcf Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Wed, 28 Aug 2024 16:53:57 +0200 Subject: [PATCH 078/113] Remove old comment Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 315028f871..50e11e0b5b 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -789,7 +789,6 @@ def run(self) -> None: target_table_id_field_name = self._schema.data_item_normalizer.C_DLT_ID # type: ignore[attr-defined] file_path = job.file_path else: - # This should look for root NOT parent. more efficient! target_table_id_field_name = self._schema.data_item_normalizer.C_DLT_ROOT_ID # type: ignore[attr-defined] file_path = self.get_parent_path(table_lineage, job.table_schema.get("parent")) From d1e417333e4396394ec6a9f3e2221e71f3bbd5dd Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 29 Aug 2024 13:33:46 +0200 Subject: [PATCH 079/113] - Change default vector column name to "vector" to conform with lancedb standard - Add search tests with tantivy as search engine Signed-off-by: Marcel Coetzee --- .../impl/lancedb/configuration.py | 2 +- poetry.lock | 44 ++++++++++- pyproject.toml | 3 +- tests/load/lancedb/test_pipeline.py | 73 ++++++++++++++++++- 4 files changed, 116 insertions(+), 6 deletions(-) diff --git a/dlt/destinations/impl/lancedb/configuration.py b/dlt/destinations/impl/lancedb/configuration.py index ba3a8b49d9..329132f495 100644 --- a/dlt/destinations/impl/lancedb/configuration.py +++ b/dlt/destinations/impl/lancedb/configuration.py @@ -90,7 +90,7 @@ class LanceDBClientConfiguration(DestinationClientDwhConfiguration): but it is configurable in rare cases. Make sure it corresponds with the associated embedding model's dimensionality.""" - vector_field_name: str = "vector__" + vector_field_name: str = "vector" """Name of the special field to store the vector embeddings.""" id_field_name: str = "id__" """Name of the special field to manage deduplication.""" diff --git a/poetry.lock b/poetry.lock index 230b354b97..1bfdb776a2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. 
[[package]] name = "about-time" @@ -8647,6 +8647,44 @@ files = [ [package.extras] widechars = ["wcwidth"] +[[package]] +name = "tantivy" +version = "0.22.0" +description = "" +optional = true +python-versions = ">=3.8" +files = [ + {file = "tantivy-0.22.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:732ec74c4dd531253af4c14756b7650527f22c7fab244e83b42d76a0a1437219"}, + {file = "tantivy-0.22.0-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:bf1da07b7e1003af4260b1ef3c3db7cb05db1578606092a6ca7a3cff2a22858a"}, + {file = "tantivy-0.22.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:689ed52985e914c531eadd8dd2df1b29f0fa684687b6026206dbdc57cf9297b2"}, + {file = "tantivy-0.22.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5f2885c8e98d1efcc4836c3e9d327d6ba2bc6b5e2cd8ac9b0356af18f571070"}, + {file = "tantivy-0.22.0-cp310-none-win_amd64.whl", hash = "sha256:4543cc72f4fec30f50fed5cd503c13d0da7cffda47648c7b72c1759103309e41"}, + {file = "tantivy-0.22.0-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:ec693abf38f229bc1361b0d34029a8bb9f3ee5bb956a3e745e0c4a66ea815bec"}, + {file = "tantivy-0.22.0-cp311-cp311-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:e385839badc12b81e38bf0a4d865ee7c3a992fea9f5ce4117adae89369e7d1eb"}, + {file = "tantivy-0.22.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b6c097d94be1af106676c86c02b185f029484fdbd9a2b9f17cb980e840e7bdad"}, + {file = "tantivy-0.22.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c47a5cdec306ea8594cb6e7effd4b430932ebfd969f9e8f99e343adf56a79bc9"}, + {file = "tantivy-0.22.0-cp311-none-win_amd64.whl", hash = "sha256:ba0ca878ed025d79edd9c51cda80b0105be8facbaec180fea64a17b80c74e7db"}, + {file = "tantivy-0.22.0-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:925682f3acb65c85c2a5a5b131401b9f30c184ea68aa73a8cc7c2ea6115e8ae3"}, + {file = "tantivy-0.22.0-cp312-cp312-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:d75760e45a329313001354d6ca415ff12d9d812343792ae133da6bfbdc4b04a5"}, + {file = "tantivy-0.22.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd909d122b5af457d955552c304f8d5d046aee7024c703c62652ad72af89f3c7"}, + {file = "tantivy-0.22.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c99266ffb204721eb2bd5b3184aa87860a6cff51b4563f808f78fa22d85a8093"}, + {file = "tantivy-0.22.0-cp312-none-win_amd64.whl", hash = "sha256:9ed6b813a1e7769444e33979b46b470b2f4c62d983c2560ce9486fb9be1491c9"}, + {file = "tantivy-0.22.0-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:97eb05f8585f321dbc733b64e7e917d061dc70c572c623730b366c216540d149"}, + {file = "tantivy-0.22.0-cp38-cp38-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:cc74748b6b886475c12bf47c8814861b79f850fb8a528f37ae0392caae0f6f14"}, + {file = "tantivy-0.22.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a7059c51c25148e07a20bd73efc8b51c015c220f141f3638489447b99229c8c0"}, + {file = "tantivy-0.22.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f88d05f55e2c3e581de70c5c7f46e94e5869d1c0fd48c5db33be7e56b6b88c9a"}, + {file = "tantivy-0.22.0-cp38-none-win_amd64.whl", hash = "sha256:09bf6de2fa08aac1a7133bee3631c1123de05130fd2991ceb101f2abac51b9d2"}, + {file = "tantivy-0.22.0-cp39-cp39-macosx_10_7_x86_64.whl", hash = 
"sha256:9de1a7497d377477dc09029c343eb9106c2c5fdb2e399f8dddd624cd9c7622a2"}, + {file = "tantivy-0.22.0-cp39-cp39-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:e81e47edd0faffb5ad20f52ae75c3a2ed680f836e72bc85c799688d3a2557502"}, + {file = "tantivy-0.22.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27333518dbc309299dafe79443ee80eede5526a489323cdb0506b95eb334f985"}, + {file = "tantivy-0.22.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c9452d05e42450be53a9a58a9cf13f9ff8d3605c73bdc38a34ce5e167a25d77"}, + {file = "tantivy-0.22.0-cp39-none-win_amd64.whl", hash = "sha256:51e4ec0d44637562bf23912d18d12850c4b3176c0719e7b019d43b59199a643c"}, + {file = "tantivy-0.22.0.tar.gz", hash = "sha256:dce07fa2910c94934aa3d96c91087936c24e4a5802d839625d67edc6d1c95e5c"}, +] + +[package.extras] +dev = ["nox"] + [[package]] name = "tblib" version = "2.0.0" @@ -9669,7 +9707,7 @@ duckdb = ["duckdb"] filesystem = ["botocore", "s3fs"] gcp = ["gcsfs", "google-cloud-bigquery", "grpcio"] gs = ["gcsfs"] -lancedb = ["lancedb", "pyarrow"] +lancedb = ["lancedb", "pyarrow", "tantivy"] motherduck = ["duckdb", "pyarrow"] mssql = ["pyodbc"] parquet = ["pyarrow"] @@ -9684,4 +9722,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "2b8d00f91f33a380b2399989dcac0d1d106d0bd2cd8865c5b7e27a19885753b5" +content-hash = "888e1760984e867fde690a1cca90330e255d69a8775c81020d003650def7ab4c" diff --git a/pyproject.toml b/pyproject.toml index d32285572f..1bdaf77b86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,7 @@ databricks-sql-connector = {version = ">=2.9.3", optional = true} clickhouse-driver = { version = ">=0.2.7", optional = true } clickhouse-connect = { version = ">=0.7.7", optional = true } lancedb = { version = ">=0.8.2", optional = true, markers = "python_version >= '3.9'", allow-prereleases = true } +tantivy = { version = ">= 0.22.0", optional = true } deltalake = { version = ">=0.19.0", optional = true } [tool.poetry.extras] @@ -105,7 +106,7 @@ qdrant = ["qdrant-client"] databricks = ["databricks-sql-connector"] clickhouse = ["clickhouse-driver", "clickhouse-connect", "s3fs", "gcsfs", "adlfs", "pyarrow"] dremio = ["pyarrow"] -lancedb = ["lancedb", "pyarrow"] +lancedb = ["lancedb", "pyarrow", "tantivy"] deltalake = ["deltalake", "pyarrow"] diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index e817a2f6c8..66fd1c180c 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -1,6 +1,7 @@ -from typing import Iterator, Generator, Any, List +from typing import Iterator, Generator, Any, List, Mapping import pytest +from lancedb import DBConnection import dlt from dlt.common import json @@ -433,3 +434,73 @@ def test_empty_dataset_allowed() -> None: assert client.dataset_name is None assert client.sentinel_table == "dltSentinelTable" assert_table(pipe, "content", expected_items_count=3) + + +search_data = [ + {"text": "Frodo was a happy puppy"}, + {"text": "There are several kittens playing"}, +] + + +def test_fts_query() -> None: + @dlt.resource + def search_data_resource() -> Generator[Mapping[str, object], Any, None]: + yield from search_data + + pipeline = dlt.pipeline( + pipeline_name="test_fts_query", + destination="lancedb", + dataset_name=f"test_pipeline_append{uniq_id()}", + ) + info = pipeline.run( + search_data_resource(), + ) + assert_load_info(info) + + client: LanceDBClient + with 
pipeline.destination_client() as client: # type: ignore[assignment] + db_client: DBConnection = client.db_client + + table_name = client.make_qualified_table_name("search_data_resource") + tbl = db_client[table_name] + tbl.checkout_latest() + + tbl.create_fts_index("text") + results = tbl.search("kittens", query_type="fts").select(["text"]).to_list() + assert results[0]["text"] == "There are several kittens playing" + + +def test_semantic_query() -> None: + @dlt.resource + def search_data_resource() -> Generator[Mapping[str, object], Any, None]: + yield from search_data + + lancedb_adapter( + search_data_resource, + embed=["text"], + ) + + pipeline = dlt.pipeline( + pipeline_name="test_fts_query", + destination="lancedb", + dataset_name=f"test_pipeline_append{uniq_id()}", + ) + info = pipeline.run( + search_data_resource(), + ) + assert_load_info(info) + + client: LanceDBClient + with pipeline.destination_client() as client: # type: ignore[assignment] + db_client: DBConnection = client.db_client + + table_name = client.make_qualified_table_name("search_data_resource") + tbl = db_client[table_name] + tbl.checkout_latest() + + results = ( + tbl.search("puppy", query_type="vector", ordering_field_name="_distance") + .select(["text"]) + .to_list() + ) + assert results[0]["text"] == "Frodo was a happy puppy" From 613f5bcbe26709a16a2a3295b91ccddad2ab38b9 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 29 Aug 2024 13:38:17 +0200 Subject: [PATCH 080/113] Format and fix linting Signed-off-by: Marcel Coetzee --- tests/load/lancedb/test_pipeline.py | 4 ++-- tests/load/lancedb/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 66fd1c180c..e8bc5aa8bf 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -1,7 +1,7 @@ from typing import Iterator, Generator, Any, List, Mapping import pytest -from lancedb import DBConnection +from lancedb import DBConnection # type: ignore import dlt from dlt.common import json @@ -22,7 +22,7 @@ @pytest.fixture(autouse=True) -def drop_lancedb_data() -> Iterator[None]: +def drop_lancedb_data() -> Iterator[Any]: yield drop_active_pipeline_data() diff --git a/tests/load/lancedb/utils.py b/tests/load/lancedb/utils.py index dc3ea5304b..7431e895b7 100644 --- a/tests/load/lancedb/utils.py +++ b/tests/load/lancedb/utils.py @@ -52,7 +52,7 @@ def assert_table( "_dlt_id", "_dlt_load_id", dlt.config.get("destination.lancedb.credentials.id_field_name", str) or "id__", - dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector__", + dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector", ] objects_without_dlt_or_special_keys = [ {k: v for k, v in record.items() if k not in drop_keys} for record in records From 703c4a87569360b4b1d28113f930d2104ee4850f Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 29 Aug 2024 16:36:55 +0200 Subject: [PATCH 081/113] Add custom embedding function registration test Signed-off-by: Marcel Coetzee --- tests/load/lancedb/test_pipeline.py | 48 ++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index e8bc5aa8bf..3904dcdb1a 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -1,7 +1,9 @@ from typing import Iterator, Generator, Any, List, Mapping +import lancedb # type: ignore import pytest 
-from lancedb import DBConnection # type: ignore +from lancedb import DBConnection +from lancedb.embeddings import EmbeddingFunctionRegistry # type: ignore import dlt from dlt.common import json @@ -504,3 +506,47 @@ def search_data_resource() -> Generator[Mapping[str, object], Any, None]: .to_list() ) assert results[0]["text"] == "Frodo was a happy puppy" + + +def test_semantic_query_custom_embedding_functions_registered() -> None: + """Test the LanceDB registry registered custom embedding functions defined in models, if any. + See: https://github.com/dlt-hub/dlt/issues/1765""" + + @dlt.resource + def search_data_resource() -> Generator[Mapping[str, object], Any, None]: + yield from search_data + + lancedb_adapter( + search_data_resource, + embed=["text"], + ) + + pipeline = dlt.pipeline( + pipeline_name="test_fts_query", + destination="lancedb", + dataset_name=f"test_pipeline_append{uniq_id()}", + ) + info = pipeline.run( + search_data_resource(), + ) + assert_load_info(info) + + client: LanceDBClient + with pipeline.destination_client() as client: # type: ignore[assignment] + db_client_uri = client.db_client.uri + table_name = client.make_qualified_table_name("search_data_resource") + + # A new python process doesn't seem to correctly deserialize the custom embedding functions into global __REGISTRY__. + EmbeddingFunctionRegistry.get_instance().reset() + + # Must read into __REGISTRY__ here. + db = lancedb.connect(db_client_uri) + tbl = db[table_name] + tbl.checkout_latest() + + results = ( + tbl.search("puppy", query_type="vector", ordering_field_name="_distance") + .select(["text"]) + .to_list() + ) + assert results[0]["text"] == "Frodo was a happy puppy" From c07c8fcf8f74f05719d6e20d896d76ab3b55ae00 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 29 Aug 2024 22:39:22 +0200 Subject: [PATCH 082/113] Spawn process in test to make sure registry can be deserialized from arrow files Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 34 +++++++++---------- dlt/destinations/impl/lancedb/models.py | 34 ------------------- tests/load/lancedb/test_pipeline.py | 18 +++++++--- 3 files changed, 31 insertions(+), 55 deletions(-) delete mode 100644 dlt/destinations/impl/lancedb/models.py diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 78a37952b9..f28fdac78d 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -1,7 +1,6 @@ import uuid from types import TracebackType from typing import ( - ClassVar, List, Any, cast, @@ -37,7 +36,6 @@ RunnableLoadJob, StorageSchemaInfo, StateInfo, - TLoadJobState, LoadJob, ) from dlt.common.pendulum import timedelta @@ -70,7 +68,6 @@ generate_uuid, set_non_standard_providers_environment_variables, ) -from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -81,6 +78,7 @@ TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} +EMPTY_STRING_PLACEHOLDER = "__EMPTY_STRING_PLACEHOLDER__" class LanceDBTypeMapper(TypeMapper): @@ -233,20 +231,11 @@ def __init__( embedding_model_provider, self.config.credentials.embedding_model_provider_api_key, ) - # Use the monkey-patched implementation if openai was chosen. 
- if embedding_model_provider == "openai": - from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings - - self.model_func = PatchedOpenAIEmbeddings( - max_retries=self.config.options.max_retries, - api_key=self.config.credentials.api_key, - ) - else: - self.model_func = self.registry.get(embedding_model_provider).create( - name=self.config.embedding_model, - max_retries=self.config.options.max_retries, - api_key=self.config.credentials.api_key, - ) + self.model_func = self.registry.get(embedding_model_provider).create( + name=self.config.embedding_model, + max_retries=self.config.options.max_retries, + api_key=self.config.credentials.api_key, + ) self.vector_field_name = self.config.vector_field_name self.id_field_name = self.config.id_field_name @@ -731,6 +720,17 @@ def run(self) -> None: with FileStorage.open_zipsafe_ro(self._file_path) as f: records: List[DictStrAny] = [json.loads(line) for line in f] + # Replace empty strings with placeholder string if OpenAI is used. + # https://github.com/lancedb/lancedb/issues/1577#issuecomment-2318104218. + if (self._job_client.config.embedding_model_provider == "openai") and ( + source_columns := get_columns_names_with_prop(self._load_table, VECTORIZE_HINT) + ): + records: List[Dict[str, Any]] + for record in records: + for k, v in record.items(): + if k in source_columns and not v: + record[k] = EMPTY_STRING_PLACEHOLDER + if self._load_table not in self._schema.dlt_tables(): for record in records: # Add reserved ID fields. diff --git a/dlt/destinations/impl/lancedb/models.py b/dlt/destinations/impl/lancedb/models.py deleted file mode 100644 index d90adb62bd..0000000000 --- a/dlt/destinations/impl/lancedb/models.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import Union, List - -import numpy as np -from lancedb.embeddings import OpenAIEmbeddings # type: ignore -from lancedb.embeddings.registry import register # type: ignore -from lancedb.embeddings.utils import TEXT # type: ignore - - -@register("openai_patched") -class PatchedOpenAIEmbeddings(OpenAIEmbeddings): - EMPTY_STRING_PLACEHOLDER: str = "___EMPTY___" - - def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]: # type: ignore[type-arg] - """ - Replace empty strings with a placeholder value. - """ - - sanitized_texts = super().sanitize_input(texts) - return [self.EMPTY_STRING_PLACEHOLDER if item == "" else item for item in sanitized_texts] - - def generate_embeddings( - self, - texts: Union[List[str], np.ndarray], # type: ignore[type-arg] - ) -> List[np.array]: # type: ignore[valid-type] - """ - Generate embeddings, treating the placeholder as an empty result. - """ - embeddings: List[np.array] = super().generate_embeddings(texts) # type: ignore[valid-type] - - for i, text in enumerate(texts): - if text == self.EMPTY_STRING_PLACEHOLDER: - embeddings[i] = np.zeros(self.ndims()) - - return embeddings diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 3904dcdb1a..728127f833 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -1,3 +1,4 @@ +import multiprocessing from typing import Iterator, Generator, Any, List, Mapping import lancedb # type: ignore @@ -536,17 +537,26 @@ def search_data_resource() -> Generator[Mapping[str, object], Any, None]: db_client_uri = client.db_client.uri table_name = client.make_qualified_table_name("search_data_resource") - # A new python process doesn't seem to correctly deserialize the custom embedding functions into global __REGISTRY__. 
- EmbeddingFunctionRegistry.get_instance().reset() + # A new python process doesn't seem to correctly deserialize the custom embedding + # functions into global __REGISTRY__. + # We make sure to reset it as well to make sure no globals are propagated to the spawned process. + EmbeddingFunctionRegistry().reset() + with multiprocessing.get_context("spawn").Pool(1) as pool: + results = pool.apply(run_lance_search_in_separate_process, (db_client_uri, table_name)) + + assert results[0]["text"] == "Frodo was a happy puppy" + + +def run_lance_search_in_separate_process(db_client_uri: str, table_name: str) -> Any: + import lancedb # Must read into __REGISTRY__ here. db = lancedb.connect(db_client_uri) tbl = db[table_name] tbl.checkout_latest() - results = ( + return ( tbl.search("puppy", query_type="vector", ordering_field_name="_distance") .select(["text"]) .to_list() ) - assert results[0]["text"] == "Frodo was a happy puppy" From 8afa7e1f400a06d412c890ea24702688b7f37e5e Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 29 Aug 2024 22:50:20 +0200 Subject: [PATCH 083/113] Simplify null string handling Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index f28fdac78d..9b240d58cb 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -725,11 +725,13 @@ def run(self) -> None: if (self._job_client.config.embedding_model_provider == "openai") and ( source_columns := get_columns_names_with_prop(self._load_table, VECTORIZE_HINT) ): - records: List[Dict[str, Any]] - for record in records: - for k, v in record.items(): - if k in source_columns and not v: - record[k] = EMPTY_STRING_PLACEHOLDER + records = [ + { + k: EMPTY_STRING_PLACEHOLDER if k in source_columns and v in ("", None) else v + for k, v in record.items() + } + for record in records + ] if self._load_table not in self._schema.dlt_tables(): for record in records: From 239543266bb8f53a6eb319f06998cf0aac53e546 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Fri, 30 Aug 2024 14:32:01 +0200 Subject: [PATCH 084/113] Change NULL string replacement with random string, doc clarification Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 2 +- .../website/docs/dlt-ecosystem/destinations/lancedb.md | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 9b240d58cb..e9acf651a3 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -78,7 +78,7 @@ TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} -EMPTY_STRING_PLACEHOLDER = "__EMPTY_STRING_PLACEHOLDER__" +EMPTY_STRING_PLACEHOLDER = "0uEoDNBpQUBwsxKbmxxB" class LanceDBTypeMapper(TypeMapper): diff --git a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md index 8b7f3854ee..5e52f8d6ab 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md +++ b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md @@ -216,11 +216,21 @@ The LanceDB destination supports syncing of the `dlt` state. 
 ## Current Limitations
 
+### In-Memory Tables
+
 Adding new fields to an existing LanceDB table requires loading the entire table data into memory as a PyArrow table.
 This is because PyArrow tables are immutable, so adding fields requires creating a new table with the updated schema.
 
 For huge tables, this may impact performance and memory usage since the full table must be loaded into memory to add the new fields.
 Keep these considerations in mind when working with large datasets and monitor memory usage if adding fields to sizable existing tables.
 
+### Null string handling for OpenAI embeddings
+
+The OpenAI embedding service doesn't accept empty string bodies. We deal with this by replacing empty strings with a placeholder that should be very semantically dissimilar to 99.9% of queries.
+
+If your source column (the column which is embedded) has empty values, it is important to consider the impact of this. There is a _slight_ chance that semantic queries can hit these empty strings.
+
+We reported this issue to LanceDB: https://github.com/lancedb/lancedb/issues/1577.
+
 

From 9a347e63d95f8fa451190a42ed9c0f4f33fca769 Mon Sep 17 00:00:00 2001
From: Marcel Coetzee <marcel@mooncoon.com>
Date: Sat, 31 Aug 2024 18:17:26 +0200
Subject: [PATCH 085/113] Update default vector column name in docs

Signed-off-by: Marcel Coetzee <marcel@mooncoon.com>
---
 docs/website/docs/dlt-ecosystem/destinations/lancedb.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md
index 5e52f8d6ab..0d726508e6 100644
--- a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md
+++ b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md
@@ -201,7 +201,7 @@ This is the default disposition. It will append the data to the existing data in
 ## Additional Destination Options
 
 - `dataset_separator`: The character used to separate the dataset name from table names. Defaults to "___".
-- `vector_field_name`: The name of the special field to store vector embeddings. Defaults to "vector__".
+- `vector_field_name`: The name of the special field to store vector embeddings. Defaults to "vector".
 - `id_field_name`: The name of the special field used for deduplication and merging. Defaults to "id__".
 - `max_retries`: The maximum number of retries for embedding operations. Set to 0 to disable retries. Defaults to 3.
 

From c0bedb7c90aa107dee01d6af46eefabfe54cbba6 Mon Sep 17 00:00:00 2001
From: Marcel Coetzee <marcel@mooncoon.com>
Date: Mon, 2 Sep 2024 10:17:35 +0200
Subject: [PATCH 086/113] Set `remove_orphans` flag to False on tests that don't require it

Signed-off-by: Marcel Coetzee <marcel@mooncoon.com>
---
 .../impl/lancedb/configuration.py             |  2 +-
 dlt/destinations/impl/lancedb/factory.py      |  2 +-
 .../impl/lancedb/lancedb_adapter.py           | 20 +++++---
 .../impl/lancedb/lancedb_client.py            | 24 +++++++---
 dlt/destinations/impl/lancedb/utils.py        | 12 +++--
 .../dlt-ecosystem/destinations/lancedb.md     |  3 +-
 tests/load/lancedb/test_merge.py              | 46 +++++++++++++++----
 tests/load/lancedb/test_pipeline.py           | 29 ++++--------
 tests/load/lancedb/utils.py                   |  3 +-
 9 files changed, 88 insertions(+), 53 deletions(-)

diff --git a/dlt/destinations/impl/lancedb/configuration.py b/dlt/destinations/impl/lancedb/configuration.py
index 92f88d562b..8f6a192bb0 100644
--- a/dlt/destinations/impl/lancedb/configuration.py
+++ b/dlt/destinations/impl/lancedb/configuration.py
@@ -91,7 +91,7 @@ class LanceDBClientConfiguration(DestinationClientDwhConfiguration):
     but it is configurable in rare cases.
Make sure it corresponds with the associated embedding model's dimensionality."""
 
-    vector_field_name: str = "vector__"
+    vector_field_name: str = "vector"
     """Name of the special field to store the vector embeddings."""
     sentinel_table_name: str = "dltSentinelTable"
     """Name of the sentinel table that encapsulates datasets. Since LanceDB has no
diff --git a/dlt/destinations/impl/lancedb/factory.py b/dlt/destinations/impl/lancedb/factory.py
index cd792b6cee..d99f0fa6ee 100644
--- a/dlt/destinations/impl/lancedb/factory.py
+++ b/dlt/destinations/impl/lancedb/factory.py
@@ -32,7 +32,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext:
 
         caps.recommended_file_size = 128_000_000
 
-        caps.supported_merge_strategies = ["delete-insert", "upsert"]
+        caps.supported_merge_strategies = ["upsert"]
 
         return caps
 
diff --git a/dlt/destinations/impl/lancedb/lancedb_adapter.py b/dlt/destinations/impl/lancedb/lancedb_adapter.py
index 8daef77f68..7a83b230f0 100644
--- a/dlt/destinations/impl/lancedb/lancedb_adapter.py
+++ b/dlt/destinations/impl/lancedb/lancedb_adapter.py
@@ -1,17 +1,20 @@
-from typing import Any
+from typing import Any, Dict
 
 from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns
 from dlt.destinations.utils import get_resource_for_adapter
 from dlt.extract import DltResource
+from dlt.extract.items import TTableHintTemplate
 
 
 VECTORIZE_HINT = "x-lancedb-embed"
+REMOVE_ORPHANS_HINT = "x-lancedb-remove-orphans"
 
 
 def lancedb_adapter(
     data: Any,
     embed: TColumnNames = None,
     merge_key: TColumnNames = None,
+    remove_orphans: bool = True,
 ) -> DltResource:
     """Prepares data for the LanceDB destination by specifying which columns should be embedded.
 
@@ -23,6 +26,8 @@ def lancedb_adapter(
         It can be a single column name as a string, or a list of column names.
         merge_key (TColumnNames, optional): Specify columns to merge on.
         It can be a single column name as a string, or a list of column names.
+        remove_orphans (bool): Specify whether to remove orphaned records in child
+            tables with no parent records after merges to maintain referential integrity.
 
     Returns:
         DltResource: A resource with applied LanceDB-specific hints.
@@ -37,6 +42,7 @@ def lancedb_adapter(
     """
     resource = get_resource_for_adapter(data)
 
+    additional_table_hints: Dict[str, TTableHintTemplate[Any]] = {}
    column_hints: TTableSchemaColumns = {}
 
    if embed:
@@ -67,12 +73,14 @@ def lancedb_adapter(
             "merge_key": True,
         }
 
-    if not column_hints:
+    additional_table_hints[REMOVE_ORPHANS_HINT] = remove_orphans
+
+    if column_hints or additional_table_hints:
+        resource.apply_hints(columns=column_hints, additional_table_hints=additional_table_hints)
+    else:
         raise ValueError(
-            "You must must provide at least either the 'embed' or 'merge_key' argument if using the"
-            " adapter."
+            "You must provide at least either the 'embed' or 'merge_key' or 'remove_orphans'"
+            " argument if using the adapter."
) - else: - resource.apply_hints(columns=column_hints) return resource diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index bcb069f628..3e373b2e83 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -43,15 +43,14 @@ FollowupJobRequest, ) from dlt.common.pendulum import timedelta -from dlt.common.schema import Schema, TTableSchema, TSchemaTables, TColumnSchema +from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.schema.typing import ( TColumnType, - TTableFormat, TTableSchemaColumns, TWriteDisposition, TColumnSchema, ) -from dlt.common.schema.utils import get_columns_names_with_prop, DEFAULT_MERGE_STRATEGY +from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName from dlt.destinations.impl.lancedb.configuration import ( LanceDBClientConfiguration, @@ -61,6 +60,7 @@ ) from dlt.destinations.impl.lancedb.lancedb_adapter import ( VECTORIZE_HINT, + REMOVE_ORPHANS_HINT, ) from dlt.destinations.impl.lancedb.schema import ( make_arrow_field_schema, @@ -75,7 +75,6 @@ from dlt.destinations.impl.lancedb.utils import ( set_non_standard_providers_environment_variables, get_default_arrow_value, - get_lancedb_merge_key, IterableWrapper, ) from dlt.destinations.job_impl import ReferenceFollowupJobRequest @@ -712,8 +711,9 @@ def create_table_chain_completed_followup_jobs( ) # Orphan removal is only supported for upsert strategy because we need a deterministic key hash. first_table_in_chain = table_chain[0] - merge_strategy = first_table_in_chain.get("x-merge-strategy", DEFAULT_MERGE_STRATEGY) - if first_table_in_chain.get("write_disposition") == "merge" and merge_strategy == "upsert": + if first_table_in_chain.get("write_disposition") == "merge" and first_table_in_chain.get( + REMOVE_ORPHANS_HINT + ): all_job_paths_ordered = [ job.file_path for table in table_chain @@ -754,6 +754,18 @@ def run(self) -> None: merge_key = self._schema.data_item_normalizer.C_DLT_ID # type: ignore[attr-defined] + # We need upsert merge's deterministic _dlt_id to perform orphan removal. + # Hence, we require at least a primary key on the root table if the merge disposition is chosen. + if ( + (self._load_table not in self._schema.dlt_table_names()) + and not self._load_table.get("parent") + and (write_disposition == "merge") + and (not get_columns_names_with_prop(self._load_table, "primary_key")) + ): + raise DestinationTerminalException( + "LanceDB's write disposition requires at least one explicit primary key." 
+ ) + write_records( arrow_table, db_client=db_client, diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 735500f9f1..ef71ee1b28 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -1,10 +1,10 @@ import os -from datetime import date, datetime from typing import Union, Dict, Optional, TypeVar, Generic, Iterable, Iterator import pyarrow as pa from dlt.common import logger +from dlt.common.pendulum import __utcnow from dlt.common.schema import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider @@ -36,9 +36,9 @@ def get_default_arrow_value(field_type: TArrowDataType) -> object: elif pa.types.is_boolean(field_type): return False elif pa.types.is_date(field_type): - return date.today() + return __utcnow().today() elif pa.types.is_timestamp(field_type): - return datetime.now() + return __utcnow() else: raise ValueError(f"Unsupported data type: {field_type}") @@ -55,11 +55,13 @@ def __init__(self, iterable: Iterable[ItemType]) -> None: def __iter__(self) -> Iterator[ItemType]: return iter(self.iterable) - def iter(self) -> Iterator[ItemType]: + def iter(self) -> Iterator[ItemType]: # noqa: A003 return iter(self.iterable) -def get_lancedb_merge_key(load_table: TTableSchema) -> Optional[Union[str, IterableWrapper[str]]]: +def get_lancedb_orphan_removal_merge_key( + load_table: TTableSchema, +) -> Optional[Union[str, IterableWrapper[str]]]: if merge_key := get_columns_names_with_prop(load_table, "merge_key"): return merge_key[0] if len(merge_key) == 1 else IterableWrapper(merge_key) elif primary_key := get_columns_names_with_prop(load_table, "primary_key"): diff --git a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md index 8b7f3854ee..ecdab8bc40 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md +++ b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md @@ -201,8 +201,7 @@ This is the default disposition. It will append the data to the existing data in ## Additional Destination Options - `dataset_separator`: The character used to separate the dataset name from table names. Defaults to "___". -- `vector_field_name`: The name of the special field to store vector embeddings. Defaults to "vector__". -- `id_field_name`: The name of the special field used for deduplication and merging. Defaults to "id__". +- `vector_field_name`: The name of the special field to store vector embeddings. Defaults to "vector". - `max_retries`: The maximum number of retries for embedding operations. Set to 0 to disable retries. Defaults to 3. 
diff --git a/tests/load/lancedb/test_merge.py b/tests/load/lancedb/test_merge.py index 120e6f4233..3d4380b4a1 100644 --- a/tests/load/lancedb/test_merge.py +++ b/tests/load/lancedb/test_merge.py @@ -8,7 +8,7 @@ from pandas.testing import assert_frame_equal import dlt -from dlt.common.typing import DictStrAny +from dlt.common.typing import DictStrAny, DictStrStr from dlt.common.utils import uniq_id from dlt.destinations.impl.lancedb.lancedb_adapter import ( lancedb_adapter, @@ -16,6 +16,7 @@ from tests.load.lancedb.utils import chunk_document from tests.load.utils import ( drop_active_pipeline_data, + sequence_generator, ) from tests.pipeline.utils import ( assert_load_info, @@ -264,9 +265,9 @@ def documents_source( df = tbl.to_pandas() # Check (non-empty) embeddings as present, and that orphaned embeddings have been discarded. - assert len(df)==21 - assert "vector__" in df.columns - for _, vector in enumerate(df["vector__"]): + assert len(df) == 21 + assert "vector" in df.columns + for _, vector in enumerate(df["vector"]): assert isinstance(vector, np.ndarray) assert vector.size > 0 @@ -281,7 +282,7 @@ def test_lancedb_compound_merge_key_root_table() -> None: @dlt.resource( table_name="root", - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": "upsert"}, primary_key=["doc_id", "chunk_hash"], merge_key=["doc_id", "chunk_hash"], ) @@ -324,10 +325,10 @@ def identity_resource( tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash", "foo"]).reset_index(drop=True) )[["doc_id", "chunk_hash", "foo"]] - assert_frame_equal(actual_root_df, expected_root_table_df) + assert_frame_equal(actual_root_df, expected_root_table_df) -def test_lancedb_compound_merge_key_root_table() -> None: +def test_lancedb_compound_merge_key_root_table_no_orphan_removal() -> None: pipeline = dlt.pipeline( pipeline_name="test_lancedb_compound_merge_key", destination="lancedb", @@ -337,7 +338,7 @@ def test_lancedb_compound_merge_key_root_table() -> None: @dlt.resource( table_name="root", - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": "upsert"}, primary_key=["doc_id", "chunk_hash"], merge_key=["doc_id", "chunk_hash"], ) @@ -346,6 +347,8 @@ def identity_resource( ) -> Generator[List[DictStrAny], None, None]: yield data + lancedb_adapter(identity_resource, remove_orphans=False) + run_1 = [ {"doc_id": 1, "chunk_hash": "a", "foo": "bar"}, {"doc_id": 1, "chunk_hash": "b", "foo": "coo"}, @@ -365,7 +368,6 @@ def identity_resource( pd.DataFrame( data=[ {"doc_id": 1, "chunk_hash": "a", "foo": "aat"}, - {"doc_id": 1, "chunk_hash": "a", "foo": "bar"}, {"doc_id": 1, "chunk_hash": "b", "foo": "coo"}, {"doc_id": 1, "chunk_hash": "c", "foo": "loot"}, ] @@ -381,4 +383,28 @@ def identity_resource( tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash", "foo"]).reset_index(drop=True) )[["doc_id", "chunk_hash", "foo"]] - assert_frame_equal(actual_root_df, expected_root_table_df) + assert_frame_equal(actual_root_df, expected_root_table_df) + + +def test_must_provide_at_least_primary_key_on_merge_disposition() -> None: + """We need upsert merge's deterministic _dlt_id to perform orphan removal. + Hence, we require at least the primary key required (raises exception if missing). 
+ Specify a merge key for custom orphan identification.""" + generator_instance1 = sequence_generator() + + @dlt.resource(write_disposition={"disposition": "merge", "strategy": "upsert"}) + def some_data() -> Generator[DictStrStr, Any, None]: + yield from next(generator_instance1) + + pipeline = dlt.pipeline( + pipeline_name="test_must_provide_both_primary_and_merge_key_on_merge_disposition", + destination="lancedb", + dataset_name=( + f"test_must_provide_both_primary_and_merge_key_on_merge_disposition{uniq_id()}" + ), + ) + with pytest.raises(Exception): + load_info = pipeline.run( + some_data(), + ) + assert_load_info(load_info) diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 4030f92f78..572d783629 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -137,7 +137,7 @@ def test_explicit_append() -> None: {"doc_id": 3, "content": "3"}, ] - @dlt.resource(merge_key="doc_id") + @dlt.resource() def some_data() -> Generator[List[DictStrAny], Any, None]: yield data @@ -276,23 +276,11 @@ def test_pipeline_merge() -> None: }, ] - @dlt.resource() + @dlt.resource(primary_key=["doc_id"]) def movies_data() -> Any: yield data - @dlt.resource(merge_key="doc_id") - def movies_data_explicit_merge_keys() -> Any: - yield data - - lancedb_adapter( - movies_data, - embed=["description"], - ) - - lancedb_adapter( - movies_data_explicit_merge_keys, - embed=["description"], - ) + lancedb_adapter(movies_data, embed=["description"], remove_orphans=False) pipeline = dlt.pipeline( pipeline_name="movies", @@ -301,7 +289,7 @@ def movies_data_explicit_merge_keys() -> Any: ) info = pipeline.run( movies_data(), - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": "upsert"}, dataset_name=f"MoviesDataset{uniq_id()}", ) assert_load_info(info) @@ -312,7 +300,7 @@ def movies_data_explicit_merge_keys() -> Any: info = pipeline.run( movies_data(), - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": "upsert"}, ) assert_load_info(info) assert_table(pipeline, "movies_data", items=data) @@ -386,9 +374,10 @@ def test_merge_github_nested() -> None: data = json.load(f) info = pipe.run( - lancedb_adapter(data[:17], embed=["title", "body"], merge_key="id"), + lancedb_adapter(data[:17], embed=["title", "body"], remove_orphans=False), table_name="issues", - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key="id", ) assert_load_info(info) # assert if schema contains tables with right names @@ -438,7 +427,7 @@ def test_merge_no_orphans() -> None: @dlt.resource( write_disposition={"disposition": "merge", "strategy": "upsert"}, table_name="document", - primary_key=["doc_id", "chunk_hash"], + primary_key=["doc_id"], merge_key=["doc_id"], ) def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: diff --git a/tests/load/lancedb/utils.py b/tests/load/lancedb/utils.py index 8dd56d22aa..8e2fddfba5 100644 --- a/tests/load/lancedb/utils.py +++ b/tests/load/lancedb/utils.py @@ -51,8 +51,7 @@ def assert_table( drop_keys = [ "_dlt_id", "_dlt_load_id", - dlt.config.get("destination.lancedb.credentials.id_field_name", str) or "id__", - dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector__", + dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector", ] objects_without_dlt_or_special_keys = [ {k: v for k, v in record.items() if k not in drop_keys} for record in records From 
5f0d6203d8e822c4da16f233c5f7c6c0a938e435 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 2 Sep 2024 11:04:23 +0200 Subject: [PATCH 087/113] Implement starter arrow string placeholder function Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/utils.py | 44 ++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index ef71ee1b28..63c99b167b 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -1,5 +1,5 @@ import os -from typing import Union, Dict, Optional, TypeVar, Generic, Iterable, Iterator +from typing import Union, Dict, Optional, TypeVar, Generic, Iterable, Iterator, List import pyarrow as pa @@ -10,7 +10,7 @@ from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider from dlt.destinations.impl.lancedb.schema import TArrowDataType - +EMPTY_STRING_PLACEHOLDER = "0uEoDNBpQUBwsxKbmxxB" PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = { "cohere": "COHERE_API_KEY", "gemini-text": "GOOGLE_API_KEY", @@ -77,3 +77,43 @@ def get_lancedb_orphan_removal_merge_key( return unique_key[0] if len(unique_key) == 1 else IterableWrapper(unique_key) else: return None + + +def fill_empty_source_column_values_with_placeholder( + table: pa.Table, source_columns: List[str], placeholder: str +) -> pa.Table: + """ + Replaces empty strings in the specified source columns of an Arrow table with a placeholder string. + + Args: + table (pa.Table): The input Arrow table. + source_columns (List[str]): A list of column names to replace empty strings in. + placeholder (str): The placeholder string to use for replacement. + + Returns: + pa.Table: The modified Arrow table with empty strings replaced in the specified columns. + """ + # Create a new table with the same schema as the input table. + new_table = table + + # Iterate over each column that needs to be modified + for column_name in source_columns: + # Get the column index + column_index = new_table.schema.get_field_index(column_name) + + # Get the column as an array + column_array = new_table.column(column_name).to_pandas() + + # Replace empty strings with the placeholder + column_array = column_array.apply(lambda x: placeholder if x=="" else x) + + # Create a new array with the modified values + new_array = pa.array(column_array) + + # Replace the column in the new table with the modified array + new_table = new_table.set_column(column_index, column_name, new_array) + + return new_table + + + From b7f30769ba6806d5c6294ef472a4a7a689d9d7e0 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 2 Sep 2024 11:05:22 +0200 Subject: [PATCH 088/113] Add test for empty arrow string element vectorised replacement utility function Signed-off-by: Marcel Coetzee --- tests/load/lancedb/test_utils.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 tests/load/lancedb/test_utils.py diff --git a/tests/load/lancedb/test_utils.py b/tests/load/lancedb/test_utils.py new file mode 100644 index 0000000000..8dbf5720c9 --- /dev/null +++ b/tests/load/lancedb/test_utils.py @@ -0,0 +1,26 @@ +import pyarrow as pa +import pytest + +from dlt.destinations.impl.lancedb.utils import fill_empty_source_column_values_with_placeholder + + +# Mark all tests as essential, don't remove. 
+pytestmark = pytest.mark.essential + + +def test_fill_empty_source_column_values_with_placeholder() -> None: + data = [pa.array(["", "hello", ""]), pa.array([1, 2, 3]), pa.array(["world", "", "arrow"])] + table = pa.Table.from_arrays(data, names=["A", "B", "C"]) + + source_columns = ["A"] + placeholder = "placeholder" + + new_table = fill_empty_source_column_values_with_placeholder(table, source_columns, placeholder) + + expected_data = [ + pa.array(["placeholder", "hello", "placeholder"]), + pa.array([1, 2, 3]), + pa.array(["world", "", "arrow"]), + ] + expected_table = pa.Table.from_arrays(expected_data, names=["A", "B", "C"]) + assert new_table.equals(expected_table) From e3a4ed0024716fd355b0278deb446c8a099dbb4f Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 2 Sep 2024 11:15:02 +0200 Subject: [PATCH 089/113] Handle NULL values in addition to empty strings in arrow substitution method Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/utils.py | 45 ++++++++++++++------------ tests/load/lancedb/test_utils.py | 14 +++++--- 2 files changed, 34 insertions(+), 25 deletions(-) diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 63c99b167b..a12c0425c2 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -83,37 +83,40 @@ def fill_empty_source_column_values_with_placeholder( table: pa.Table, source_columns: List[str], placeholder: str ) -> pa.Table: """ - Replaces empty strings in the specified source columns of an Arrow table with a placeholder string. + Replaces empty strings and null values in the specified source columns of an Arrow table with a placeholder string. Args: table (pa.Table): The input Arrow table. - source_columns (List[str]): A list of column names to replace empty strings in. + source_columns (List[str]): A list of column names to replace empty strings and null values in. placeholder (str): The placeholder string to use for replacement. Returns: - pa.Table: The modified Arrow table with empty strings replaced in the specified columns. + pa.Table: The modified Arrow table with empty strings and null values replaced in the specified columns. """ - # Create a new table with the same schema as the input table. 
- new_table = table + new_columns = [] - # Iterate over each column that needs to be modified - for column_name in source_columns: - # Get the column index - column_index = new_table.schema.get_field_index(column_name) + for col_name in table.column_names: + column = table[col_name] + if col_name in source_columns: + # Process each chunk separately + new_chunks = [] + for chunk in column.chunks: + # Replace null values with the placeholder + filled_chunk = pa.compute.fill_null(chunk, fill_value=placeholder) + # Replace empty strings with the placeholder using regex + new_chunk = pa.compute.replace_substring_regex( + filled_chunk, pattern=r"^$", replacement=placeholder + ) + new_chunks.append(new_chunk) - # Get the column as an array - column_array = new_table.column(column_name).to_pandas() + # Combine the processed chunks into a new ChunkedArray + new_column = pa.chunked_array(new_chunks) + else: + new_column = column - # Replace empty strings with the placeholder - column_array = column_array.apply(lambda x: placeholder if x=="" else x) - - # Create a new array with the modified values - new_array = pa.array(column_array) - - # Replace the column in the new table with the modified array - new_table = new_table.set_column(column_index, column_name, new_array) - - return new_table + new_columns.append(new_column) + return pa.Table.from_arrays(new_columns, names=table.column_names) +1 diff --git a/tests/load/lancedb/test_utils.py b/tests/load/lancedb/test_utils.py index 8dbf5720c9..2f517aac8e 100644 --- a/tests/load/lancedb/test_utils.py +++ b/tests/load/lancedb/test_utils.py @@ -9,18 +9,24 @@ def test_fill_empty_source_column_values_with_placeholder() -> None: - data = [pa.array(["", "hello", ""]), pa.array([1, 2, 3]), pa.array(["world", "", "arrow"])] - table = pa.Table.from_arrays(data, names=["A", "B", "C"]) + data = [ + pa.array(["", "hello", ""]), + pa.array(["hello", None, ""]), + pa.array([1, 2, 3]), + pa.array(["world", "", "arrow"]), + ] + table = pa.Table.from_arrays(data, names=["A", "B", "C", "D"]) - source_columns = ["A"] + source_columns = ["A", "B"] placeholder = "placeholder" new_table = fill_empty_source_column_values_with_placeholder(table, source_columns, placeholder) expected_data = [ pa.array(["placeholder", "hello", "placeholder"]), + pa.array(["hello", "placeholder", "placeholder"]), pa.array([1, 2, 3]), pa.array(["world", "", "arrow"]), ] - expected_table = pa.Table.from_arrays(expected_data, names=["A", "B", "C"]) + expected_table = pa.Table.from_arrays(expected_data, names=["A", "B", "C", "D"]) assert new_table.equals(expected_table) From 4ec894fc4cd800c51963134584c91b6d27b1d2c2 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 2 Sep 2024 11:22:30 +0200 Subject: [PATCH 090/113] More efficient empty value replacement with canonical arrow usage Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 22 +++-------- dlt/destinations/impl/lancedb/utils.py | 39 +++++-------------- tests/load/lancedb/test_pipeline.py | 5 +-- 3 files changed, 17 insertions(+), 49 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 7d83c3d542..cd9a372ecf 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -76,6 +76,8 @@ set_non_standard_providers_environment_variables, get_default_arrow_value, IterableWrapper, + EMPTY_STRING_PLACEHOLDER, + fill_empty_source_column_values_with_placeholder, ) from dlt.destinations.job_impl 
import ReferenceFollowupJobRequest from dlt.destinations.type_mapping import TypeMapper @@ -88,7 +90,6 @@ TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} BATCH_PROCESS_CHUNK_SIZE = 10_000 -EMPTY_STRING_PLACEHOLDER = "0uEoDNBpQUBwsxKbmxxB" class LanceDBTypeMapper(TypeMapper): @@ -749,23 +750,10 @@ def run(self) -> None: if (self._job_client.config.embedding_model_provider == "openai") and ( source_columns := get_columns_names_with_prop(self._load_table, VECTORIZE_HINT) ): - records = [ - { - k: EMPTY_STRING_PLACEHOLDER if k in source_columns and v in ("", None) else v - for k, v in record.items() - } - for record in records - ] + arrow_table = fill_empty_source_column_values_with_placeholder( + arrow_table, source_columns, EMPTY_STRING_PLACEHOLDER + ) - if self._load_table not in self._schema.dlt_tables(): - for record in records: - # Add reserved ID fields. - uuid_id = ( - generate_uuid(record, unique_identifiers, self._fq_table_name) - if unique_identifiers - else str(uuid.uuid4()) - ) - record.update({self._id_field_name: uuid_id}) merge_key = self._schema.data_item_normalizer.C_DLT_ID # type: ignore[attr-defined] # We need upsert merge's deterministic _dlt_id to perform orphan removal. diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index a12c0425c2..07e0ca4bb3 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -63,18 +63,18 @@ def get_lancedb_orphan_removal_merge_key( load_table: TTableSchema, ) -> Optional[Union[str, IterableWrapper[str]]]: if merge_key := get_columns_names_with_prop(load_table, "merge_key"): - return merge_key[0] if len(merge_key) == 1 else IterableWrapper(merge_key) + return merge_key[0] if len(merge_key)==1 else IterableWrapper(merge_key) elif primary_key := get_columns_names_with_prop(load_table, "primary_key"): # No merge key defined, warn and merge on the primary key. logger.warning( "Merge strategy selected without defined merge key - using primary key as merge key." ) - return primary_key[0] if len(primary_key) == 1 else IterableWrapper(merge_key) + return primary_key[0] if len(primary_key)==1 else IterableWrapper(merge_key) elif unique_key := get_columns_names_with_prop(load_table, "unique"): logger.warning( "Merge strategy selected without defined merge key - using unique key as merge key." ) - return unique_key[0] if len(unique_key) == 1 else IterableWrapper(unique_key) + return unique_key[0] if len(unique_key)==1 else IterableWrapper(unique_key) else: return None @@ -93,30 +93,11 @@ def fill_empty_source_column_values_with_placeholder( Returns: pa.Table: The modified Arrow table with empty strings and null values replaced in the specified columns. 
""" - new_columns = [] - - for col_name in table.column_names: + for col_name in source_columns: column = table[col_name] - if col_name in source_columns: - # Process each chunk separately - new_chunks = [] - for chunk in column.chunks: - # Replace null values with the placeholder - filled_chunk = pa.compute.fill_null(chunk, fill_value=placeholder) - # Replace empty strings with the placeholder using regex - new_chunk = pa.compute.replace_substring_regex( - filled_chunk, pattern=r"^$", replacement=placeholder - ) - new_chunks.append(new_chunk) - - # Combine the processed chunks into a new ChunkedArray - new_column = pa.chunked_array(new_chunks) - else: - new_column = column - - new_columns.append(new_column) - - return pa.Table.from_arrays(new_columns, names=table.column_names) - - -1 + filled_column = pa.compute.fill_null(column, fill_value=placeholder) + new_column = pa.compute.replace_substring_regex( + filled_column, pattern=r"^$", replacement=placeholder + ) + table = table.set_column(table.column_names.index(col_name), col_name, new_column) + return table diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 93c7b55487..e79c0c4dd4 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -1,11 +1,10 @@ import multiprocessing -from typing import Iterator, Generator, Any, List, Mapping from typing import Iterator, Generator, Any, List +from typing import Mapping from typing import Union, Dict -import lancedb # type: ignore import pytest -from lancedb import DBConnection +from lancedb import DBConnection # type: ignore from lancedb.embeddings import EmbeddingFunctionRegistry # type: ignore from lancedb.table import Table # type: ignore From 9866874f23f2732c74b93be653a705cb7cdbf8ca Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 2 Sep 2024 11:25:37 +0200 Subject: [PATCH 091/113] Format Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 07e0ca4bb3..ee556b3553 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -63,18 +63,18 @@ def get_lancedb_orphan_removal_merge_key( load_table: TTableSchema, ) -> Optional[Union[str, IterableWrapper[str]]]: if merge_key := get_columns_names_with_prop(load_table, "merge_key"): - return merge_key[0] if len(merge_key)==1 else IterableWrapper(merge_key) + return merge_key[0] if len(merge_key) == 1 else IterableWrapper(merge_key) elif primary_key := get_columns_names_with_prop(load_table, "primary_key"): # No merge key defined, warn and merge on the primary key. logger.warning( "Merge strategy selected without defined merge key - using primary key as merge key." ) - return primary_key[0] if len(primary_key)==1 else IterableWrapper(merge_key) + return primary_key[0] if len(primary_key) == 1 else IterableWrapper(merge_key) elif unique_key := get_columns_names_with_prop(load_table, "unique"): logger.warning( "Merge strategy selected without defined merge key - using unique key as merge key." 
) - return unique_key[0] if len(unique_key)==1 else IterableWrapper(unique_key) + return unique_key[0] if len(unique_key) == 1 else IterableWrapper(unique_key) else: return None From 7099d5f581f3c3cac77979ef691790450a234e6c Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 2 Sep 2024 12:29:50 +0200 Subject: [PATCH 092/113] Bump pyarrow version Signed-off-by: Marcel Coetzee --- poetry.lock | 108 +++++++++++++++++++++++++------------------------ pyproject.toml | 2 +- 2 files changed, 57 insertions(+), 53 deletions(-) diff --git a/poetry.lock b/poetry.lock index 1bfdb776a2..0ce139d08f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2102,32 +2102,33 @@ typing-extensions = ">=3.10.0" [[package]] name = "databricks-sql-connector" -version = "3.3.0" +version = "2.9.6" description = "Databricks SQL Connector for Python" optional = true -python-versions = "<4.0.0,>=3.8.0" +python-versions = "<4.0.0,>=3.7.1" files = [ - {file = "databricks_sql_connector-3.3.0-py3-none-any.whl", hash = "sha256:55ee5a4a11291bf91a235ac76e41b419ddd66a9a321065a8bfaf119acbb26d6b"}, - {file = "databricks_sql_connector-3.3.0.tar.gz", hash = "sha256:19e82965da4c86574adfe9f788c17b4494d98eb8075ba4fd4306573d2edbf194"}, + {file = "databricks_sql_connector-2.9.6-py3-none-any.whl", hash = "sha256:d830abf86e71d2eb83c6a7b7264d6c03926a8a83cec58541ddd6b83d693bde8f"}, + {file = "databricks_sql_connector-2.9.6.tar.gz", hash = "sha256:e55f5b8ede8ae6c6f31416a4cf6352f0ac019bf6875896c668c7574ceaf6e813"}, ] [package.dependencies] +alembic = ">=1.0.11,<2.0.0" lz4 = ">=4.0.2,<5.0.0" numpy = [ - {version = ">=1.16.6,<2.0.0", markers = "python_version >= \"3.8\" and python_version < \"3.11\""}, - {version = ">=1.23.4,<2.0.0", markers = "python_version >= \"3.11\""}, + {version = ">=1.16.6", markers = "python_version >= \"3.7\" and python_version < \"3.11\""}, + {version = ">=1.23.4", markers = "python_version >= \"3.11\""}, ] oauthlib = ">=3.1.0,<4.0.0" openpyxl = ">=3.0.10,<4.0.0" -pandas = {version = ">=1.2.5,<2.2.0", markers = "python_version >= \"3.8\""} -pyarrow = ">=14.0.1,<17" +pandas = {version = ">=1.2.5,<3.0.0", markers = "python_version >= \"3.8\""} +pyarrow = [ + {version = ">=6.0.0", markers = "python_version >= \"3.7\" and python_version < \"3.11\""}, + {version = ">=10.0.1", markers = "python_version >= \"3.11\""}, +] requests = ">=2.18.1,<3.0.0" -thrift = ">=0.16.0,<0.21.0" -urllib3 = ">=1.26" - -[package.extras] -alembic = ["alembic (>=1.0.11,<2.0.0)", "sqlalchemy (>=2.0.21)"] -sqlalchemy = ["sqlalchemy (>=2.0.21)"] +sqlalchemy = ">=1.3.24,<2.0.0" +thrift = ">=0.16.0,<0.17.0" +urllib3 = ">=1.0" [[package]] name = "dbt-athena-community" @@ -6659,52 +6660,55 @@ files = [ [[package]] name = "pyarrow" -version = "16.1.0" +version = "17.0.0" description = "Python library for Apache Arrow" optional = false python-versions = ">=3.8" files = [ - {file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"}, - {file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", 
hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"}, - {file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"}, - {file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"}, - {file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"}, - {file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"}, - {file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"}, - {file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"}, - {file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"}, - {file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"}, - {file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"}, - {file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = 
"sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"}, - {file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"}, - {file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"}, - {file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"}, - {file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"}, + {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"}, + {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da1e060b3876faa11cee287839f9cc7cdc00649f475714b8680a05fd9071d545"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75c06d4624c0ad6674364bb46ef38c3132768139ddec1c56582dbac54f2663e2"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:fa3c246cc58cb5a4a5cb407a18f193354ea47dd0648194e6265bd24177982fe8"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:f7ae2de664e0b158d1607699a16a488de3d008ba99b3a7aa5de1cbc13574d047"}, + {file = "pyarrow-17.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5984f416552eea15fd9cee03da53542bf4cddaef5afecefb9aa8d1010c335087"}, + {file = "pyarrow-17.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:1c8856e2ef09eb87ecf937104aacfa0708f22dfeb039c363ec99735190ffb977"}, + {file = "pyarrow-17.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e19f569567efcbbd42084e87f948778eb371d308e137a0f97afe19bb860ccb3"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b244dc8e08a23b3e352899a006a26ae7b4d0da7bb636872fa8f5884e70acf15"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b72e87fe3e1db343995562f7fff8aee354b55ee83d13afba65400c178ab2597"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dc5c31c37409dfbc5d014047817cb4ccd8c1ea25d19576acf1a001fe07f5b420"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e3343cb1e88bc2ea605986d4b94948716edc7a8d14afd4e2c097232f729758b4"}, + {file = "pyarrow-17.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:a27532c38f3de9eb3e90ecab63dfda948a8ca859a66e3a47f5f42d1e403c4d03"}, + {file = "pyarrow-17.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9b8a823cea605221e61f34859dcc03207e52e409ccf6354634143e23af7c8d22"}, + {file = 
"pyarrow-17.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f1e70de6cb5790a50b01d2b686d54aaf73da01266850b05e3af2a1bc89e16053"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0071ce35788c6f9077ff9ecba4858108eebe2ea5a3f7cf2cf55ebc1dbc6ee24a"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:757074882f844411fcca735e39aae74248a1531367a7c80799b4266390ae51cc"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ba11c4f16976e89146781a83833df7f82077cdab7dc6232c897789343f7891a"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b0c6ac301093b42d34410b187bba560b17c0330f64907bfa4f7f7f2444b0cf9b"}, + {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"}, + {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"}, + {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"}, + {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"}, + {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"}, + {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"}, ] [package.dependencies] numpy = ">=1.16.6" +[package.extras] +test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] + [[package]] name = "pyasn1" version = "0.5.0" @@ -9722,4 +9726,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "888e1760984e867fde690a1cca90330e255d69a8775c81020d003650def7ab4c" +content-hash = 
"1d8fa59c9ef876d699cb5b5a2fcadb9a78c4c4d28a9fca7ca0e83147c08feaae" diff --git a/pyproject.toml b/pyproject.toml index 1bdaf77b86..53ef7a5d94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -219,7 +219,7 @@ dbt-duckdb = ">=1.2.0" pymongo = ">=4.3.3" pandas = ">2" alive-progress = ">=3.0.1" -pyarrow = ">=14.0.0" +pyarrow = ">=17.0.0" psycopg2-binary = ">=2.9" lancedb = { version = ">=0.8.2", markers = "python_version >= '3.9'", allow-prereleases = true } openai = ">=1.35" From 1c770d1d7699bf5cb7756d6ab61e7f1bc64b2b9b Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 2 Sep 2024 12:42:13 +0200 Subject: [PATCH 093/113] Use pa.nulls instead of [None]*len Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index cd9a372ecf..873c69cae3 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -833,8 +833,9 @@ def run(self) -> None: payload_arrow_table = payload_arrow_table.append_column(field, default_array) except ValueError as e: logger.warn(f"{e}. Using null values for field '{field.name}'.") - null_array = pa.array([None] * payload_arrow_table.num_rows, type=field.type) - payload_arrow_table = payload_arrow_table.append_column(field, null_array) + payload_arrow_table = payload_arrow_table.append_column( + field, pa.nulls(size=payload_arrow_table.num_rows, type=field.type) + ) write_records( payload_arrow_table, From 0b11ac7730585975621ffe47153abca299e72cf4 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 2 Sep 2024 13:04:04 +0200 Subject: [PATCH 094/113] Update tests Signed-off-by: Marcel Coetzee --- tests/load/lancedb/test_merge.py | 79 ++++++++------------------------ 1 file changed, 19 insertions(+), 60 deletions(-) diff --git a/tests/load/lancedb/test_merge.py b/tests/load/lancedb/test_merge.py index 3d4380b4a1..ca1fe66d46 100644 --- a/tests/load/lancedb/test_merge.py +++ b/tests/load/lancedb/test_merge.py @@ -76,16 +76,24 @@ def identity_resource( { "id": 1, "child": [{"bar": 1, "grandchild": [{"baz": 1}]}], - }, # Removed one child and one grandchild + }, # Removed one child and one grandchild. { "id": 2, "child": [{"bar": 4, "grandchild": [{"baz": 8}]}], - }, # Changed child and grandchild + }, # Changed child and grandchild. 
] info = pipeline.run(identity_resource(run_2)) assert_load_info(info) with pipeline.destination_client() as client: + expected_parent_data = pd.DataFrame( + data=[ + {"id": 1}, + {"id": 2}, + {"id": 3}, + ] + ) + expected_child_data = pd.DataFrame( data=[ {"bar": 1}, @@ -105,32 +113,39 @@ def identity_resource( ] ) + parent_table_name = client.make_qualified_table_name("parent") # type: ignore[attr-defined] child_table_name = client.make_qualified_table_name("parent__child") # type: ignore[attr-defined] grandchild_table_name = client.make_qualified_table_name( # type: ignore[attr-defined] "parent__child__grandchild" ) + parent_tbl = client.db_client.open_table(parent_table_name) # type: ignore[attr-defined] child_tbl = client.db_client.open_table(child_table_name) # type: ignore[attr-defined] grandchild_tbl = client.db_client.open_table(grandchild_table_name) # type: ignore[attr-defined] + actual_parent_df = ( + parent_tbl.to_pandas() + .sort_values(by="id") + .reset_index(drop=True) + ) actual_child_df = ( child_tbl.to_pandas() .sort_values(by="bar") .reset_index(drop=True) - .reset_index(drop=True) ) actual_grandchild_df = ( grandchild_tbl.to_pandas() .sort_values(by="baz") .reset_index(drop=True) - .reset_index(drop=True) ) + expected_parent_data = expected_parent_data.sort_values(by="id").reset_index(drop=True) expected_child_data = expected_child_data.sort_values(by="bar").reset_index(drop=True) expected_grandchild_data = expected_grandchild_data.sort_values(by="baz").reset_index( drop=True ) + assert_frame_equal(actual_parent_df[["id"]], expected_parent_data) assert_frame_equal(actual_child_df[["bar"]], expected_child_data) assert_frame_equal(actual_grandchild_df[["baz"]], expected_grandchild_data) @@ -280,62 +295,6 @@ def test_lancedb_compound_merge_key_root_table() -> None: dev_mode=True, ) - @dlt.resource( - table_name="root", - write_disposition={"disposition": "merge", "strategy": "upsert"}, - primary_key=["doc_id", "chunk_hash"], - merge_key=["doc_id", "chunk_hash"], - ) - def identity_resource( - data: List[DictStrAny], - ) -> Generator[List[DictStrAny], None, None]: - yield data - - run_1 = [ - {"doc_id": 1, "chunk_hash": "a", "foo": "bar"}, - {"doc_id": 1, "chunk_hash": "b", "foo": "coo"}, - ] - info = pipeline.run(identity_resource(run_1)) - assert_load_info(info) - - run_2 = [ - {"doc_id": 1, "chunk_hash": "a", "foo": "aat"}, - {"doc_id": 1, "chunk_hash": "c", "foo": "loot"}, - ] - info = pipeline.run(identity_resource(run_2)) - assert_load_info(info) - - with pipeline.destination_client() as client: - expected_root_table_df = ( - pd.DataFrame( - data=[ - {"doc_id": 1, "chunk_hash": "a", "foo": "aat"}, - {"doc_id": 1, "chunk_hash": "b", "foo": "coo"}, - {"doc_id": 1, "chunk_hash": "c", "foo": "loot"}, - ] - ) - .sort_values(by=["doc_id", "chunk_hash", "foo"]) - .reset_index(drop=True) - ) - - root_table_name = client.make_qualified_table_name("root") # type: ignore[attr-defined] - tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] - - actual_root_df: DataFrame = ( - tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash", "foo"]).reset_index(drop=True) - )[["doc_id", "chunk_hash", "foo"]] - - assert_frame_equal(actual_root_df, expected_root_table_df) - - -def test_lancedb_compound_merge_key_root_table_no_orphan_removal() -> None: - pipeline = dlt.pipeline( - pipeline_name="test_lancedb_compound_merge_key", - destination="lancedb", - dataset_name=f"test_lancedb_remove_orphaned_records_root_table_{uniq_id()}", - dev_mode=True, - ) - 
@dlt.resource( table_name="root", write_disposition={"disposition": "merge", "strategy": "upsert"}, From e81736e9d1ed1d177cfd4a607608dbd707337c65 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 2 Sep 2024 13:37:04 +0200 Subject: [PATCH 095/113] Invert remove orphans flag Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_adapter.py | 8 ++++---- dlt/destinations/impl/lancedb/lancedb_client.py | 6 +++--- tests/load/lancedb/test_merge.py | 2 +- tests/load/lancedb/test_pipeline.py | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_adapter.py b/dlt/destinations/impl/lancedb/lancedb_adapter.py index 7a83b230f0..8f4fbb091d 100644 --- a/dlt/destinations/impl/lancedb/lancedb_adapter.py +++ b/dlt/destinations/impl/lancedb/lancedb_adapter.py @@ -7,14 +7,14 @@ VECTORIZE_HINT = "x-lancedb-embed" -REMOVE_ORPHANS_HINT = "x-lancedb-remove-orphans" +NO_REMOVE_ORPHANS_HINT = "x-lancedb-remove-orphans" def lancedb_adapter( data: Any, embed: TColumnNames = None, merge_key: TColumnNames = None, - remove_orphans: bool = True, + no_remove_orphans: bool = False, ) -> DltResource: """Prepares data for the LanceDB destination by specifying which columns should be embedded. @@ -26,7 +26,7 @@ def lancedb_adapter( It can be a single column name as a string, or a list of column names. merge_key (TColumnNames, optional): Specify columns to merge on. It can be a single column name as a string, or a list of column names. - remove_orphans (bool): Specify whether to remove orphaned records in child + no_remove_orphans (bool): Specify whether to remove orphaned records in child tables with no parent records after merges to maintain referential integrity. Returns: @@ -73,7 +73,7 @@ def lancedb_adapter( "merge_key": True, } - additional_table_hints[REMOVE_ORPHANS_HINT] = remove_orphans + additional_table_hints[NO_REMOVE_ORPHANS_HINT] = no_remove_orphans if column_hints or additional_table_hints: resource.apply_hints(columns=column_hints, additional_table_hints=additional_table_hints) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 873c69cae3..d4a3d53284 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -60,7 +60,7 @@ ) from dlt.destinations.impl.lancedb.lancedb_adapter import ( VECTORIZE_HINT, - REMOVE_ORPHANS_HINT, + NO_REMOVE_ORPHANS_HINT, ) from dlt.destinations.impl.lancedb.schema import ( make_arrow_field_schema, @@ -704,8 +704,8 @@ def create_table_chain_completed_followup_jobs( ) # Orphan removal is only supported for upsert strategy because we need a deterministic key hash. 
first_table_in_chain = table_chain[0] - if first_table_in_chain.get("write_disposition") == "merge" and first_table_in_chain.get( - REMOVE_ORPHANS_HINT + if first_table_in_chain.get("write_disposition") == "merge" and not first_table_in_chain.get( + NO_REMOVE_ORPHANS_HINT ): all_job_paths_ordered = [ job.file_path diff --git a/tests/load/lancedb/test_merge.py b/tests/load/lancedb/test_merge.py index ca1fe66d46..95b721295b 100644 --- a/tests/load/lancedb/test_merge.py +++ b/tests/load/lancedb/test_merge.py @@ -306,7 +306,7 @@ def identity_resource( ) -> Generator[List[DictStrAny], None, None]: yield data - lancedb_adapter(identity_resource, remove_orphans=False) + lancedb_adapter(identity_resource, no_remove_orphans=True) run_1 = [ {"doc_id": 1, "chunk_hash": "a", "foo": "bar"}, diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index e79c0c4dd4..aef36aa315 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -284,7 +284,7 @@ def test_pipeline_merge() -> None: def movies_data() -> Any: yield data - lancedb_adapter(movies_data, embed=["description"], remove_orphans=False) + lancedb_adapter(movies_data, embed=["description"], no_remove_orphans=True) pipeline = dlt.pipeline( pipeline_name="movies", @@ -378,7 +378,7 @@ def test_merge_github_nested() -> None: data = json.load(f) info = pipe.run( - lancedb_adapter(data[:17], embed=["title", "body"], remove_orphans=False), + lancedb_adapter(data[:17], embed=["title", "body"], no_remove_orphans=True), table_name="issues", write_disposition={"disposition": "merge", "strategy": "upsert"}, primary_key="id", From 36abec75665e398a457a6b0116c6c0b5598633e8 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 2 Sep 2024 16:52:33 +0200 Subject: [PATCH 096/113] Implement root table orphan deletion, only integer doc_ids Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 52 ++++++++---- dlt/destinations/impl/lancedb/utils.py | 44 ++++------ tests/load/lancedb/test_merge.py | 82 ++++++++++++++++--- 3 files changed, 121 insertions(+), 57 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index d4a3d53284..1a29d7cbe9 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -17,6 +17,7 @@ import lancedb # type: ignore import lancedb.table # type: ignore import pyarrow as pa +import pyarrow.compute as pc import pyarrow.parquet as pq from lancedb import DBConnection from lancedb.common import DATA # type: ignore @@ -75,9 +76,9 @@ from dlt.destinations.impl.lancedb.utils import ( set_non_standard_providers_environment_variables, get_default_arrow_value, - IterableWrapper, EMPTY_STRING_PLACEHOLDER, fill_empty_source_column_values_with_placeholder, + get_canonical_vector_database_doc_id_merge_key, ) from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.type_mapping import TypeMapper @@ -171,8 +172,9 @@ def write_records( db_client: DBConnection, table_name: str, write_disposition: Optional[TWriteDisposition] = "append", - merge_key: Optional[Union[str, IterableWrapper[str]]] = None, + merge_key: Optional[str] = None, remove_orphans: Optional[bool] = False, + filter_condition: Optional[str] = None, ) -> None: """Inserts records into a LanceDB table with automatic embedding computation. @@ -183,6 +185,8 @@ def write_records( merge_key: Keys for update/merge operations. 
write_disposition: The write disposition - one of 'skip', 'append', 'replace', 'merge'. remove_orphans (bool): Whether to remove orphans after insertion or not (only merge disposition). + filter_condition (str): If None, then all such rows will be deleted. + Otherwise, the condition will be used as an SQL filter to limit what rows are deleted. Raises: ValueError: If the write disposition is unsupported, or `id_field_name` is not @@ -204,7 +208,9 @@ def write_records( tbl.add(records, mode="overwrite") elif write_disposition == "merge": if remove_orphans: - tbl.merge_insert(merge_key).when_not_matched_by_source_delete().execute(records) + tbl.merge_insert(merge_key).when_not_matched_by_source_delete( + filter_condition + ).execute(records) else: tbl.merge_insert( merge_key @@ -704,9 +710,9 @@ def create_table_chain_completed_followup_jobs( ) # Orphan removal is only supported for upsert strategy because we need a deterministic key hash. first_table_in_chain = table_chain[0] - if first_table_in_chain.get("write_disposition") == "merge" and not first_table_in_chain.get( - NO_REMOVE_ORPHANS_HINT - ): + if first_table_in_chain.get( + "write_disposition" + ) == "merge" and not first_table_in_chain.get(NO_REMOVE_ORPHANS_HINT): all_job_paths_ordered = [ job.file_path for table in table_chain @@ -837,14 +843,32 @@ def run(self) -> None: field, pa.nulls(size=payload_arrow_table.num_rows, type=field.type) ) - write_records( - payload_arrow_table, - db_client=db_client, - merge_key=target_table_id_field_name, - table_name=fq_table_name, - write_disposition="merge", - remove_orphans=True, - ) + if target_is_root_table: + canonical_doc_id_field = get_canonical_vector_database_doc_id_merge_key( + job.table_schema + ) + unique_doc_ids = pc.unique(payload_arrow_table[canonical_doc_id_field]).to_pylist() + filter_condition = ( + f"{canonical_doc_id_field} in (" + ",".join(map(str, unique_doc_ids)) + ")" + ) + write_records( + payload_arrow_table, + db_client=db_client, + table_name=fq_table_name, + write_disposition="merge", + merge_key=self._schema.data_item_normalizer.C_DLT_LOAD_ID, # type: ignore[attr-defined] + remove_orphans=True, + filter_condition=filter_condition, + ) + else: + write_records( + payload_arrow_table, + db_client=db_client, + table_name=fq_table_name, + write_disposition="merge", + merge_key=target_table_id_field_name, + remove_orphans=True, + ) @staticmethod def get_parent_path(table_lineage: TTableLineage, table: str) -> Any: diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index ee556b3553..eba3375c9d 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -1,9 +1,10 @@ import os -from typing import Union, Dict, Optional, TypeVar, Generic, Iterable, Iterator, List +from typing import Union, Dict, List import pyarrow as pa from dlt.common import logger +from dlt.common.destination.exceptions import DestinationTerminalException from dlt.common.pendulum import __utcnow from dlt.common.schema import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop @@ -43,41 +44,24 @@ def get_default_arrow_value(field_type: TArrowDataType) -> object: raise ValueError(f"Unsupported data type: {field_type}") -ItemType = TypeVar("ItemType") - - -# LanceDB `merge_insert` expects an 'iter()' method instead of using standard iteration. 
-# https://github.com/lancedb/lancedb/blob/ae85008714792a6b724c75793b63273c51caba88/python/python/lancedb/table.py#L2264 -class IterableWrapper(Generic[ItemType]): - def __init__(self, iterable: Iterable[ItemType]) -> None: - self.iterable = iterable - - def __iter__(self) -> Iterator[ItemType]: - return iter(self.iterable) - - def iter(self) -> Iterator[ItemType]: # noqa: A003 - return iter(self.iterable) - - -def get_lancedb_orphan_removal_merge_key( +def get_canonical_vector_database_doc_id_merge_key( load_table: TTableSchema, -) -> Optional[Union[str, IterableWrapper[str]]]: +) -> str: if merge_key := get_columns_names_with_prop(load_table, "merge_key"): - return merge_key[0] if len(merge_key) == 1 else IterableWrapper(merge_key) + if len(merge_key) > 1: + raise DestinationTerminalException(f"You cannot specify multiple merge keys with LanceDB orphan remove enabled: {merge_key}") + else: + return merge_key[0] elif primary_key := get_columns_names_with_prop(load_table, "primary_key"): - # No merge key defined, warn and merge on the primary key. - logger.warning( - "Merge strategy selected without defined merge key - using primary key as merge key." - ) - return primary_key[0] if len(primary_key) == 1 else IterableWrapper(merge_key) - elif unique_key := get_columns_names_with_prop(load_table, "unique"): + # No merge key defined, warn and assume the first element of the primary key is `doc_id`. logger.warning( - "Merge strategy selected without defined merge key - using unique key as merge key." + f"Merge strategy selected without defined merge key - using the first element of the primary key ({primary_key}) as merge key." ) - return unique_key[0] if len(unique_key) == 1 else IterableWrapper(unique_key) + return primary_key[0] else: - return None - + raise DestinationTerminalException( + "You must specify at least a primary key in order to perform orphan removal." 
+ ) def fill_empty_source_column_values_with_placeholder( table: pa.Table, source_columns: List[str], placeholder: str diff --git a/tests/load/lancedb/test_merge.py b/tests/load/lancedb/test_merge.py index 95b721295b..8c151316af 100644 --- a/tests/load/lancedb/test_merge.py +++ b/tests/load/lancedb/test_merge.py @@ -123,20 +123,10 @@ def identity_resource( child_tbl = client.db_client.open_table(child_table_name) # type: ignore[attr-defined] grandchild_tbl = client.db_client.open_table(grandchild_table_name) # type: ignore[attr-defined] - actual_parent_df = ( - parent_tbl.to_pandas() - .sort_values(by="id") - .reset_index(drop=True) - ) - actual_child_df = ( - child_tbl.to_pandas() - .sort_values(by="bar") - .reset_index(drop=True) - ) + actual_parent_df = parent_tbl.to_pandas().sort_values(by="id").reset_index(drop=True) + actual_child_df = child_tbl.to_pandas().sort_values(by="bar").reset_index(drop=True) actual_grandchild_df = ( - grandchild_tbl.to_pandas() - .sort_values(by="baz") - .reset_index(drop=True) + grandchild_tbl.to_pandas().sort_values(by="baz").reset_index(drop=True) ) expected_parent_data = expected_parent_data.sort_values(by="id").reset_index(drop=True) @@ -169,6 +159,8 @@ def identity_resource( ) -> Generator[List[DictStrAny], None, None]: yield data + lancedb_adapter(identity_resource) + run_1 = [ {"doc_id": 1, "chunk_hash": "1a"}, {"doc_id": 2, "chunk_hash": "2a"}, @@ -212,6 +204,70 @@ def identity_resource( assert_frame_equal(actual_root_df, expected_root_table_df) +def test_lancedb_remove_orphaned_records_root_table_string_doc_id() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_remove_orphaned_records_root_table", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_root_table_{uniq_id()}", + dev_mode=True, + ) + + @dlt.resource( + table_name="root", + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key=["doc_id", "chunk_hash"], + merge_key=["doc_id"], + ) + def identity_resource( + data: List[DictStrAny], + ) -> Generator[List[DictStrAny], None, None]: + yield data + + lancedb_adapter(identity_resource) + + run_1 = [ + {"doc_id": "A", "chunk_hash": "1a"}, + {"doc_id": "B", "chunk_hash": "2a"}, + {"doc_id": "B", "chunk_hash": "2b"}, + {"doc_id": "B", "chunk_hash": "2c"}, + {"doc_id": "C", "chunk_hash": "3a"}, + {"doc_id": "C", "chunk_hash": "3b"}, + ] + info = pipeline.run(identity_resource(run_1)) + assert_load_info(info) + + run_2 = [ + {"doc_id": "B", "chunk_hash": "2d"}, + {"doc_id": "B", "chunk_hash": "2e"}, + {"doc_id": "C", "chunk_hash": "3b"}, + ] + info = pipeline.run(identity_resource(run_2)) + assert_load_info(info) + + with pipeline.destination_client() as client: + expected_root_table_df = ( + pd.DataFrame( + data=[ + {"doc_id": "A", "chunk_hash": "1a"}, + {"doc_id": "B", "chunk_hash": "2d"}, + {"doc_id": "B", "chunk_hash": "2e"}, + {"doc_id": "C", "chunk_hash": "3b"}, + ] + ) + .sort_values(by=["doc_id", "chunk_hash"]) + .reset_index(drop=True) + ) + + root_table_name = client.make_qualified_table_name("root") # type: ignore[attr-defined] + tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] + + actual_root_df: DataFrame = ( + tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash"]).reset_index(drop=True) + )[["doc_id", "chunk_hash"]] + + assert_frame_equal(actual_root_df, expected_root_table_df) + + def test_lancedb_root_table_remove_orphaned_records_with_real_embeddings() -> None: @dlt.resource( write_disposition={"disposition": "merge", "strategy": 
"upsert"}, From 5ceeda984a72a25cfc6ab459c3dfdfe0baf654a7 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 2 Sep 2024 17:09:13 +0200 Subject: [PATCH 097/113] Cater for string ids as well in doc_id removal process Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 5 ++--- dlt/destinations/impl/lancedb/utils.py | 19 +++++++++++++++++-- tests/load/lancedb/test_merge.py | 2 +- tests/load/lancedb/test_pipeline.py | 2 +- 4 files changed, 21 insertions(+), 7 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 1a29d7cbe9..72f889be2d 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -79,6 +79,7 @@ EMPTY_STRING_PLACEHOLDER, fill_empty_source_column_values_with_placeholder, get_canonical_vector_database_doc_id_merge_key, + create_filter_condition, ) from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.type_mapping import TypeMapper @@ -848,9 +849,7 @@ def run(self) -> None: job.table_schema ) unique_doc_ids = pc.unique(payload_arrow_table[canonical_doc_id_field]).to_pylist() - filter_condition = ( - f"{canonical_doc_id_field} in (" + ",".join(map(str, unique_doc_ids)) + ")" - ) + filter_condition = create_filter_condition(canonical_doc_id_field, unique_doc_ids) write_records( payload_arrow_table, db_client=db_client, diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index eba3375c9d..610a5b4837 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -49,13 +49,17 @@ def get_canonical_vector_database_doc_id_merge_key( ) -> str: if merge_key := get_columns_names_with_prop(load_table, "merge_key"): if len(merge_key) > 1: - raise DestinationTerminalException(f"You cannot specify multiple merge keys with LanceDB orphan remove enabled: {merge_key}") + raise DestinationTerminalException( + "You cannot specify multiple merge keys with LanceDB orphan remove enabled:" + f" {merge_key}" + ) else: return merge_key[0] elif primary_key := get_columns_names_with_prop(load_table, "primary_key"): # No merge key defined, warn and assume the first element of the primary key is `doc_id`. logger.warning( - f"Merge strategy selected without defined merge key - using the first element of the primary key ({primary_key}) as merge key." + "Merge strategy selected without defined merge key - using the first element of the" + f" primary key ({primary_key}) as merge key." ) return primary_key[0] else: @@ -63,6 +67,7 @@ def get_canonical_vector_database_doc_id_merge_key( "You must specify at least a primary key in order to perform orphan removal." 
) + def fill_empty_source_column_values_with_placeholder( table: pa.Table, source_columns: List[str], placeholder: str ) -> pa.Table: @@ -85,3 +90,13 @@ def fill_empty_source_column_values_with_placeholder( ) table = table.set_column(table.column_names.index(col_name), col_name, new_column) return table + + +def create_filter_condition( + canonical_doc_id_field: str, unique_doc_ids: List[Union[str, int, float]] +) -> str: + def format_value(x: Union[str, int, float]) -> str: + return f"'{x}'" if isinstance(x, str) else str(x) + + formatted_ids = ", ".join(map(format_value, unique_doc_ids)) + return f"{canonical_doc_id_field} IN ({formatted_ids})" diff --git a/tests/load/lancedb/test_merge.py b/tests/load/lancedb/test_merge.py index 8c151316af..519014f081 100644 --- a/tests/load/lancedb/test_merge.py +++ b/tests/load/lancedb/test_merge.py @@ -33,7 +33,7 @@ def drop_lancedb_data() -> Iterator[None]: drop_active_pipeline_data() -def test_lancedb_remove_orphaned_records() -> None: +def test_lancedb_remove_nested_orphaned_records() -> None: pipeline = dlt.pipeline( pipeline_name="test_lancedb_remove_orphaned_records", destination="lancedb", diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index aef36aa315..dcbe0eb04e 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -427,7 +427,7 @@ def test_empty_dataset_allowed() -> None: assert_table(pipe, "content", expected_items_count=3) -def test_merge_no_orphans() -> None: +def test_lancedb_remove_nested_orphaned_records_with_chunks() -> None: @dlt.resource( write_disposition={"disposition": "merge", "strategy": "upsert"}, table_name="document", From a8f9c3b23009f4bc6e678e90d99c50adf690dec8 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 2 Sep 2024 17:12:26 +0200 Subject: [PATCH 098/113] Fix test with wrong primary key Signed-off-by: Marcel Coetzee --- tests/load/lancedb/test_merge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/load/lancedb/test_merge.py b/tests/load/lancedb/test_merge.py index 519014f081..bcde2dc93e 100644 --- a/tests/load/lancedb/test_merge.py +++ b/tests/load/lancedb/test_merge.py @@ -272,7 +272,7 @@ def test_lancedb_root_table_remove_orphaned_records_with_real_embeddings() -> No @dlt.resource( write_disposition={"disposition": "merge", "strategy": "upsert"}, table_name="document", - primary_key="doc_id", + primary_key=["doc_id", "chunk"], merge_key="doc_id", ) def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: From b3baf93d356e212dad30a5f01cd80abf991d37dd Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 2 Sep 2024 17:26:50 +0200 Subject: [PATCH 099/113] Just send list of ids as is. 
don't pc.compute on client end Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 7 ++++--- dlt/destinations/impl/lancedb/utils.py | 13 ++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 72f889be2d..54ef76c880 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -17,7 +17,6 @@ import lancedb # type: ignore import lancedb.table # type: ignore import pyarrow as pa -import pyarrow.compute as pc import pyarrow.parquet as pq from lancedb import DBConnection from lancedb.common import DATA # type: ignore @@ -848,8 +847,10 @@ def run(self) -> None: canonical_doc_id_field = get_canonical_vector_database_doc_id_merge_key( job.table_schema ) - unique_doc_ids = pc.unique(payload_arrow_table[canonical_doc_id_field]).to_pylist() - filter_condition = create_filter_condition(canonical_doc_id_field, unique_doc_ids) + # TODO: Guard against edge cases. For example, if `doc_id` field has escape characters in it. + filter_condition = create_filter_condition( + canonical_doc_id_field, payload_arrow_table[canonical_doc_id_field] + ) write_records( payload_arrow_table, db_client=db_client, diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 610a5b4837..e4f3b1f90d 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -92,11 +92,10 @@ def fill_empty_source_column_values_with_placeholder( return table -def create_filter_condition( - canonical_doc_id_field: str, unique_doc_ids: List[Union[str, int, float]] -) -> str: - def format_value(x: Union[str, int, float]) -> str: - return f"'{x}'" if isinstance(x, str) else str(x) +def create_filter_condition(canonical_doc_id_field: str, id_column: pa.Array) -> str: + def format_value(element: Union[str, int, float, pa.Scalar]) -> str: + if isinstance(element, pa.Scalar): + element = element.as_py() + return "'" + element.replace("'", "''") + "'" if isinstance(element, str) else str(element) - formatted_ids = ", ".join(map(format_value, unique_doc_ids)) - return f"{canonical_doc_id_field} IN ({formatted_ids})" + return f"{canonical_doc_id_field} IN ({', '.join(map(format_value, id_column))})" From 589071ce03aafff0a2cb1d4763afcc7e2f55f03b Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 2 Sep 2024 17:44:01 +0200 Subject: [PATCH 100/113] Extract schema matching into utils Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 31 ++++---------- dlt/destinations/impl/lancedb/utils.py | 40 +++++++++++++++++++ 2 files changed, 48 insertions(+), 23 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 54ef76c880..1f1206995a 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -79,6 +79,7 @@ fill_empty_source_column_values_with_placeholder, get_canonical_vector_database_doc_id_merge_key, create_filter_condition, + add_missing_columns_to_arrow_table, ) from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.type_mapping import TypeMapper @@ -812,36 +813,20 @@ def run(self) -> None: fq_table_name = self._job_client.make_qualified_table_name(job.table_name) if target_is_root_table: - target_table_id_field_name = self._schema.data_item_normalizer.C_DLT_ID # type: 
ignore[attr-defined] file_path = job.file_path else: - target_table_id_field_name = self._schema.data_item_normalizer.C_DLT_ROOT_ID # type: ignore[attr-defined] - file_path = self.get_parent_path(table_lineage, job.table_schema.get("parent")) + file_path = self.get_parent_paths(table_lineage, job.table_schema.get("parent")) with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: payload_arrow_table: pa.Table = pq.read_table(f) - # Get target table schema + # Get target table schema. with FileStorage.open_zipsafe_ro(job.file_path, mode="rb") as f: target_table_schema: pa.Schema = pq.read_schema(f) - # LanceDB requires the payload to have all fields populated, even if we don't intend to use them in our merge operation. - # Unfortunately, we can't just create NULL fields; else LanceDB always truncates the target using `when_not_matched_by_source_delete`. - schema_difference = pa.schema( - set(target_table_schema) - set(payload_arrow_table.schema) + payload_arrow_table = add_missing_columns_to_arrow_table( + payload_arrow_table, target_table_schema ) - for field in schema_difference: - try: - default_value = get_default_arrow_value(field.type) - default_array = pa.array( - [default_value] * payload_arrow_table.num_rows, type=field.type - ) - payload_arrow_table = payload_arrow_table.append_column(field, default_array) - except ValueError as e: - logger.warn(f"{e}. Using null values for field '{field.name}'.") - payload_arrow_table = payload_arrow_table.append_column( - field, pa.nulls(size=payload_arrow_table.num_rows, type=field.type) - ) if target_is_root_table: canonical_doc_id_field = get_canonical_vector_database_doc_id_merge_key( @@ -866,10 +851,10 @@ def run(self) -> None: db_client=db_client, table_name=fq_table_name, write_disposition="merge", - merge_key=target_table_id_field_name, + merge_key=self._schema.data_item_normalizer.C_DLT_ROOT_ID, # type: ignore[attr-defined] remove_orphans=True, ) @staticmethod - def get_parent_path(table_lineage: TTableLineage, table: str) -> Any: - return next(entry.file_path for entry in table_lineage if entry.table_name == table) + def get_parent_paths(table_lineage: TTableLineage, table: str) -> List[str]: + return [entry.file_path for entry in table_lineage if entry.table_name == table][0] diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index e4f3b1f90d..d107475fe7 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -99,3 +99,43 @@ def format_value(element: Union[str, int, float, pa.Scalar]) -> str: return "'" + element.replace("'", "''") + "'" if isinstance(element, str) else str(element) return f"{canonical_doc_id_field} IN ({', '.join(map(format_value, id_column))})" + + +def add_missing_columns_to_arrow_table( + payload_arrow_table: pa.Table, + target_table_schema: pa.Schema, +) -> pa.Table: + """Add missing columns from the target schema to the payload Arrow table. + + LanceDB requires the payload to have all fields populated, even if we + don't intend to use them in our merge operation. + + Unfortunately, we can't just create NULL fields; else LanceDB always truncates + the target using `when_not_matched_by_source_delete`. + This function identifies columns present in the target schema but missing from + the payload table and adds them with either default or null values. + + Args: + payload_arrow_table: The input Arrow table. + target_table_schema: The schema of the target table. + + Returns: + The modified Arrow table with added columns. 
+ + """ + schema_difference = pa.schema(set(target_table_schema) - set(payload_arrow_table.schema)) + + for field in schema_difference: + try: + default_value = get_default_arrow_value(field.type) + default_array = pa.array( + [default_value] * payload_arrow_table.num_rows, type=field.type + ) + payload_arrow_table = payload_arrow_table.append_column(field, default_array) + except ValueError as e: + logger.warning(f"{e}. Using null values for field '{field.name}'.") + payload_arrow_table = payload_arrow_table.append_column( + field, pa.nulls(size=payload_arrow_table.num_rows, type=field.type) + ) + + return payload_arrow_table From a86a13a58ff0fa0a57567abdbb36baad21eb28ef Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 2 Sep 2024 23:04:25 +0200 Subject: [PATCH 101/113] Add utils Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 63 ++++++++++++------- dlt/destinations/impl/lancedb/utils.py | 10 +++ tests/load/lancedb/test_merge.py | 2 + 3 files changed, 51 insertions(+), 24 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 1f1206995a..53a35ecfb2 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -74,12 +74,12 @@ ) from dlt.destinations.impl.lancedb.utils import ( set_non_standard_providers_environment_variables, - get_default_arrow_value, EMPTY_STRING_PLACEHOLDER, fill_empty_source_column_values_with_placeholder, get_canonical_vector_database_doc_id_merge_key, create_filter_condition, add_missing_columns_to_arrow_table, + get_root_table_name, ) from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.type_mapping import TypeMapper @@ -814,21 +814,14 @@ def run(self) -> None: if target_is_root_table: file_path = job.file_path - else: - file_path = self.get_parent_paths(table_lineage, job.table_schema.get("parent")) - - with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: - payload_arrow_table: pa.Table = pq.read_table(f) - - # Get target table schema. - with FileStorage.open_zipsafe_ro(job.file_path, mode="rb") as f: - target_table_schema: pa.Schema = pq.read_schema(f) + with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: + payload_arrow_table: pa.Table = pq.read_table(f) + target_table_schema: pa.Schema = pq.read_schema(f) - payload_arrow_table = add_missing_columns_to_arrow_table( - payload_arrow_table, target_table_schema - ) + payload_arrow_table = add_missing_columns_to_arrow_table( + payload_arrow_table, target_table_schema + ) - if target_is_root_table: canonical_doc_id_field = get_canonical_vector_database_doc_id_merge_key( job.table_schema ) @@ -845,16 +838,38 @@ def run(self) -> None: remove_orphans=True, filter_condition=filter_condition, ) + else: - write_records( - payload_arrow_table, - db_client=db_client, - table_name=fq_table_name, - write_disposition="merge", - merge_key=self._schema.data_item_normalizer.C_DLT_ROOT_ID, # type: ignore[attr-defined] - remove_orphans=True, - ) + # Use root table load history to identify orphans. 
+ root_table_name = get_root_table_name(job.table_schema, self._schema) + root_table_file_paths = self.get_table_paths(table_lineage, root_table_name) + for file_path in root_table_file_paths: + with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: + payload_arrow_table: pa.Table = pq.read_table(f) # type: ignore[no-redef] + with FileStorage.open_zipsafe_ro(job.file_path, mode="rb") as f: + target_table_schema: pa.Schema = pq.read_schema(f) # type: ignore[no-redef] + + # Merge key needs to be same for both parent and child. + # Since we intend to merge on dlt_root_id, we rename the parent table dlt_id. + payload_arrow_table_names: List[str] = payload_arrow_table.column_names + payload_arrow_table_names[ + payload_arrow_table_names.index(self._schema.data_item_normalizer.C_DLT_ID) # type: ignore[attr-defined] + ] = self._schema.data_item_normalizer.C_DLT_ROOT_ID # type: ignore[attr-defined] + payload_arrow_table = payload_arrow_table.rename_columns( + payload_arrow_table_names + ) + payload_arrow_table = add_missing_columns_to_arrow_table( + payload_arrow_table, target_table_schema + ) + write_records( + payload_arrow_table, + db_client=db_client, + table_name=fq_table_name, + write_disposition="merge", + merge_key=self._schema.data_item_normalizer.C_DLT_ROOT_ID, # type: ignore[attr-defined] + remove_orphans=True, + ) @staticmethod - def get_parent_paths(table_lineage: TTableLineage, table: str) -> List[str]: - return [entry.file_path for entry in table_lineage if entry.table_name == table][0] + def get_table_paths(table_lineage: TTableLineage, table: str) -> List[str]: + return [entry.file_path for entry in table_lineage if entry.table_name == table] diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index d107475fe7..2952c6ee1f 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -3,6 +3,7 @@ import pyarrow as pa +from dlt import Schema from dlt.common import logger from dlt.common.destination.exceptions import DestinationTerminalException from dlt.common.pendulum import __utcnow @@ -139,3 +140,12 @@ def add_missing_columns_to_arrow_table( ) return payload_arrow_table + + +def get_root_table_name(table: TTableSchema, schema: Schema) -> str: + """Identify a table's root table.""" + if parent_name := table.get("parent"): + parent = schema.get_table(parent_name) + return get_root_table_name(parent, schema) + else: + return table["name"] diff --git a/tests/load/lancedb/test_merge.py b/tests/load/lancedb/test_merge.py index bcde2dc93e..ca1060661a 100644 --- a/tests/load/lancedb/test_merge.py +++ b/tests/load/lancedb/test_merge.py @@ -52,6 +52,8 @@ def identity_resource( ) -> Generator[List[DictStrAny], None, None]: yield data + # lancedb_adapter(identity_resource, no_remove_orphans=True) + run_1 = [ { "id": 1, From 0eba25e656e8b32cd22d53486b4a29305148d3f4 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 3 Sep 2024 01:01:53 +0200 Subject: [PATCH 102/113] Pass all tests Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 82 ++++++------------- dlt/destinations/impl/lancedb/utils.py | 13 +-- tests/load/lancedb/test_merge.py | 6 +- 3 files changed, 30 insertions(+), 71 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 53a35ecfb2..7892a8c2c7 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -79,7 +79,6 @@ 
get_canonical_vector_database_doc_id_merge_key, create_filter_condition, add_missing_columns_to_arrow_table, - get_root_table_name, ) from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.type_mapping import TypeMapper @@ -761,13 +760,11 @@ def run(self) -> None: arrow_table, source_columns, EMPTY_STRING_PLACEHOLDER ) - merge_key = self._schema.data_item_normalizer.C_DLT_ID # type: ignore[attr-defined] - # We need upsert merge's deterministic _dlt_id to perform orphan removal. # Hence, we require at least a primary key on the root table if the merge disposition is chosen. if ( (self._load_table not in self._schema.dlt_table_names()) - and not self._load_table.get("parent") + and not self._load_table.get("parent") # Is root table. and (write_disposition == "merge") and (not get_columns_names_with_prop(self._load_table, "primary_key")) ): @@ -780,7 +777,7 @@ def run(self) -> None: db_client=db_client, table_name=fq_table_name, write_disposition=write_disposition, - merge_key=merge_key, + merge_key=self._schema.data_item_normalizer.C_DLT_ID, # type: ignore[attr-defined] ) @@ -811,17 +808,16 @@ def run(self) -> None: for job in table_lineage: target_is_root_table = "parent" not in job.table_schema fq_table_name = self._job_client.make_qualified_table_name(job.table_name) + file_path = job.file_path + with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: + payload_arrow_table: pa.Table = pq.read_table(f) + target_table_schema: pa.Schema = pq.read_schema(f) - if target_is_root_table: - file_path = job.file_path - with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: - payload_arrow_table: pa.Table = pq.read_table(f) - target_table_schema: pa.Schema = pq.read_schema(f) - - payload_arrow_table = add_missing_columns_to_arrow_table( - payload_arrow_table, target_table_schema - ) + payload_arrow_table = add_missing_columns_to_arrow_table( + payload_arrow_table, target_table_schema + ) + if target_is_root_table: canonical_doc_id_field = get_canonical_vector_database_doc_id_merge_key( job.table_schema ) @@ -829,47 +825,21 @@ def run(self) -> None: filter_condition = create_filter_condition( canonical_doc_id_field, payload_arrow_table[canonical_doc_id_field] ) - write_records( - payload_arrow_table, - db_client=db_client, - table_name=fq_table_name, - write_disposition="merge", - merge_key=self._schema.data_item_normalizer.C_DLT_LOAD_ID, # type: ignore[attr-defined] - remove_orphans=True, - filter_condition=filter_condition, - ) + merge_key = self._schema.data_item_normalizer.C_DLT_LOAD_ID # type: ignore[attr-defined] else: - # Use root table load history to identify orphans. - root_table_name = get_root_table_name(job.table_schema, self._schema) - root_table_file_paths = self.get_table_paths(table_lineage, root_table_name) - for file_path in root_table_file_paths: - with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: - payload_arrow_table: pa.Table = pq.read_table(f) # type: ignore[no-redef] - with FileStorage.open_zipsafe_ro(job.file_path, mode="rb") as f: - target_table_schema: pa.Schema = pq.read_schema(f) # type: ignore[no-redef] - - # Merge key needs to be same for both parent and child. - # Since we intend to merge on dlt_root_id, we rename the parent table dlt_id. 
- payload_arrow_table_names: List[str] = payload_arrow_table.column_names - payload_arrow_table_names[ - payload_arrow_table_names.index(self._schema.data_item_normalizer.C_DLT_ID) # type: ignore[attr-defined] - ] = self._schema.data_item_normalizer.C_DLT_ROOT_ID # type: ignore[attr-defined] - payload_arrow_table = payload_arrow_table.rename_columns( - payload_arrow_table_names - ) - payload_arrow_table = add_missing_columns_to_arrow_table( - payload_arrow_table, target_table_schema - ) - write_records( - payload_arrow_table, - db_client=db_client, - table_name=fq_table_name, - write_disposition="merge", - merge_key=self._schema.data_item_normalizer.C_DLT_ROOT_ID, # type: ignore[attr-defined] - remove_orphans=True, - ) - - @staticmethod - def get_table_paths(table_lineage: TTableLineage, table: str) -> List[str]: - return [entry.file_path for entry in table_lineage if entry.table_name == table] + filter_condition = create_filter_condition( + self._schema.data_item_normalizer.C_DLT_ROOT_ID, # type: ignore[attr-defined] + payload_arrow_table[self._schema.data_item_normalizer.C_DLT_ROOT_ID], # type: ignore[attr-defined] + ) + merge_key = self._schema.data_item_normalizer.C_DLT_ID # type: ignore[attr-defined] + + write_records( + payload_arrow_table, + db_client=db_client, + table_name=fq_table_name, + write_disposition="merge", + merge_key=merge_key, + remove_orphans=True, + filter_condition=filter_condition, + ) diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 2952c6ee1f..525f1cec7a 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -93,13 +93,13 @@ def fill_empty_source_column_values_with_placeholder( return table -def create_filter_condition(canonical_doc_id_field: str, id_column: pa.Array) -> str: +def create_filter_condition(field_name: str, array: pa.Array) -> str: def format_value(element: Union[str, int, float, pa.Scalar]) -> str: if isinstance(element, pa.Scalar): element = element.as_py() return "'" + element.replace("'", "''") + "'" if isinstance(element, str) else str(element) - return f"{canonical_doc_id_field} IN ({', '.join(map(format_value, id_column))})" + return f"{field_name} IN ({', '.join(map(format_value, array))})" def add_missing_columns_to_arrow_table( @@ -140,12 +140,3 @@ def add_missing_columns_to_arrow_table( ) return payload_arrow_table - - -def get_root_table_name(table: TTableSchema, schema: Schema) -> str: - """Identify a table's root table.""" - if parent_name := table.get("parent"): - parent = schema.get_table(parent_name) - return get_root_table_name(parent, schema) - else: - return table["name"] diff --git a/tests/load/lancedb/test_merge.py b/tests/load/lancedb/test_merge.py index ca1060661a..f04c846df7 100644 --- a/tests/load/lancedb/test_merge.py +++ b/tests/load/lancedb/test_merge.py @@ -52,8 +52,6 @@ def identity_resource( ) -> Generator[List[DictStrAny], None, None]: yield data - # lancedb_adapter(identity_resource, no_remove_orphans=True) - run_1 = [ { "id": 1, @@ -78,11 +76,11 @@ def identity_resource( { "id": 1, "child": [{"bar": 1, "grandchild": [{"baz": 1}]}], - }, # Removed one child and one grandchild. + }, # Removes bar_2, baz_2 and baz_3. { "id": 2, "child": [{"bar": 4, "grandchild": [{"baz": 8}]}], - }, # Changed child and grandchild. + }, # Removes bar_3, baz_4. 
] info = pipeline.run(identity_resource(run_2)) assert_load_info(info) From 2b7f4c6d2a72399123a2c20da05d136b76b9f177 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 3 Sep 2024 01:13:52 +0200 Subject: [PATCH 103/113] Minor format and cleanup Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 19 +++--- dlt/destinations/impl/lancedb/utils.py | 60 ------------------- 2 files changed, 8 insertions(+), 71 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 7892a8c2c7..11249d0f97 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -78,7 +78,6 @@ fill_empty_source_column_values_with_placeholder, get_canonical_vector_database_doc_id_merge_key, create_filter_condition, - add_missing_columns_to_arrow_table, ) from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.type_mapping import TypeMapper @@ -793,6 +792,10 @@ def __init__( self.references = ReferenceFollowupJobRequest.resolve_references(file_path) def run(self) -> None: + dlt_load_id = self._schema.data_item_normalizer.C_DLT_LOAD_ID # type: ignore[attr-defined] + dlt_id = self._schema.data_item_normalizer.C_DLT_ID # type: ignore[attr-defined] + dlt_root_id = self._schema.data_item_normalizer.C_DLT_ROOT_ID # type: ignore[attr-defined] + db_client: DBConnection = self._job_client.db_client table_lineage: TTableLineage = [ TableJob( @@ -811,28 +814,22 @@ def run(self) -> None: file_path = job.file_path with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: payload_arrow_table: pa.Table = pq.read_table(f) - target_table_schema: pa.Schema = pq.read_schema(f) - - payload_arrow_table = add_missing_columns_to_arrow_table( - payload_arrow_table, target_table_schema - ) if target_is_root_table: canonical_doc_id_field = get_canonical_vector_database_doc_id_merge_key( job.table_schema ) - # TODO: Guard against edge cases. For example, if `doc_id` field has escape characters in it. 
filter_condition = create_filter_condition( canonical_doc_id_field, payload_arrow_table[canonical_doc_id_field] ) - merge_key = self._schema.data_item_normalizer.C_DLT_LOAD_ID # type: ignore[attr-defined] + merge_key = dlt_load_id else: filter_condition = create_filter_condition( - self._schema.data_item_normalizer.C_DLT_ROOT_ID, # type: ignore[attr-defined] - payload_arrow_table[self._schema.data_item_normalizer.C_DLT_ROOT_ID], # type: ignore[attr-defined] + dlt_root_id, + payload_arrow_table[dlt_root_id], ) - merge_key = self._schema.data_item_normalizer.C_DLT_ID # type: ignore[attr-defined] + merge_key = dlt_id write_records( payload_arrow_table, diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 525f1cec7a..f07f2754d2 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -3,14 +3,11 @@ import pyarrow as pa -from dlt import Schema from dlt.common import logger from dlt.common.destination.exceptions import DestinationTerminalException -from dlt.common.pendulum import __utcnow from dlt.common.schema import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider -from dlt.destinations.impl.lancedb.schema import TArrowDataType EMPTY_STRING_PLACEHOLDER = "0uEoDNBpQUBwsxKbmxxB" PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = { @@ -28,23 +25,6 @@ def set_non_standard_providers_environment_variables( os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or "" -def get_default_arrow_value(field_type: TArrowDataType) -> object: - if pa.types.is_integer(field_type): - return 0 - elif pa.types.is_floating(field_type): - return 0.0 - elif pa.types.is_string(field_type): - return "" - elif pa.types.is_boolean(field_type): - return False - elif pa.types.is_date(field_type): - return __utcnow().today() - elif pa.types.is_timestamp(field_type): - return __utcnow() - else: - raise ValueError(f"Unsupported data type: {field_type}") - - def get_canonical_vector_database_doc_id_merge_key( load_table: TTableSchema, ) -> str: @@ -100,43 +80,3 @@ def format_value(element: Union[str, int, float, pa.Scalar]) -> str: return "'" + element.replace("'", "''") + "'" if isinstance(element, str) else str(element) return f"{field_name} IN ({', '.join(map(format_value, array))})" - - -def add_missing_columns_to_arrow_table( - payload_arrow_table: pa.Table, - target_table_schema: pa.Schema, -) -> pa.Table: - """Add missing columns from the target schema to the payload Arrow table. - - LanceDB requires the payload to have all fields populated, even if we - don't intend to use them in our merge operation. - - Unfortunately, we can't just create NULL fields; else LanceDB always truncates - the target using `when_not_matched_by_source_delete`. - This function identifies columns present in the target schema but missing from - the payload table and adds them with either default or null values. - - Args: - payload_arrow_table: The input Arrow table. - target_table_schema: The schema of the target table. - - Returns: - The modified Arrow table with added columns. 
- - """ - schema_difference = pa.schema(set(target_table_schema) - set(payload_arrow_table.schema)) - - for field in schema_difference: - try: - default_value = get_default_arrow_value(field.type) - default_array = pa.array( - [default_value] * payload_arrow_table.num_rows, type=field.type - ) - payload_arrow_table = payload_arrow_table.append_column(field, default_array) - except ValueError as e: - logger.warning(f"{e}. Using null values for field '{field.name}'.") - payload_arrow_table = payload_arrow_table.append_column( - field, pa.nulls(size=payload_arrow_table.num_rows, type=field.type) - ) - - return payload_arrow_table From ea36b00e44af25aecae6e7f1353dc59fa1e493c0 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 3 Sep 2024 02:01:45 +0200 Subject: [PATCH 104/113] Docs Signed-off-by: Marcel Coetzee --- .../dlt-ecosystem/destinations/lancedb.md | 45 ++++++++++++++----- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md index c6dcd16862..ba72a98acc 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md +++ b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md @@ -181,37 +181,60 @@ info = pipeline.run( ### Merge -The [merge](../../general-usage/incremental-loading.md) write disposition merges the data from the resource with the data at the destination based on a unique identifier. +The [merge](../../general-usage/incremental-loading.md) write disposition merges the data from the resource with the data at the destination based on a unique identifier. The LanceDB destination merge write disposition only supports upsert strategy. This updates existing records and inserts new ones based on a unique identifier. + +You can specify the merge disposition, primary key, and merge key either in a resource or adapter: + +```py +@dlt.resource( + primary_key=["doc_id", "chunk_id"], + merge_key=["doc_id"], + write_disposition={"disposition": "merge", "strategy": "upsert"}, +) +def my_rag_docs( + data: List[DictStrAny], +) -> Generator[List[DictStrAny], None, None]: + yield data +``` + +Or: ```py pipeline.run( lancedb_adapter( - movies, - embed="title", + my_new_rag_docs, + merge_key="doc_id" ), - write_disposition="merge", - primary_key="id", + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key=["doc_id", "chunk_id"], ) ``` +The `primary_key` uniquely identifies each record, typically comprising a document ID and a chunk ID. +The `merge_key`, which cannot be compound, should correspond to the canonical `doc_id` used in vector databases and represent the document identifier in your data model. +It must be the first element of the `primary_key`. +This `merge_key` is crucial for document identification and orphan removal during merge operations. +This structure ensures proper record identification and maintains consistency with vector database concepts. + + #### Orphan Removal -To maintain referential integrity between parent document tables and chunk tables, you can automatically remove orphaned chunks when updating or deleting parent documents. -Specify the "x-lancedb-doc-id" hint as follows: +LanceDB **automatically removes orphaned chunks** when updating or deleting parent documents during a merge operation. To disable this feature: ```py pipeline.run( lancedb_adapter( movies, embed="title", - document_id="id" + no_remove_orphans=True # Disable with the `no_remove_orphans` flag. 
), - write_disposition="merge", + write_disposition={"disposition": "merge", "strategy": "upsert"}, + primary_key=["doc_id", "chunk_id"], ) ``` -This sets `document_id` as the primary key and uses it to remove orphans in root tables and child tables recursively. -While it's technically possible to set both a primary key, and the `document_id` hint separately, doing so leads to confusing behavior and should be avoided. +Note: While it's possible to omit the `merge_key` for brevity (in which case it is assumed to be the first entry of `primary_key`), +explicitly specifying both is recommended for clarity. ### Append From 81eaea924e33fa12b2d66ab54ff9c12dbb85cb3b Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 5 Sep 2024 15:21:31 +0200 Subject: [PATCH 105/113] Amend replace test to test with large number of records to catch race conditions with replace disposition Signed-off-by: Marcel Coetzee --- tests/load/lancedb/test_pipeline.py | 14 ++++++-------- tests/load/utils.py | 2 +- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index dcbe0eb04e..89b3b4b3bc 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -1,4 +1,5 @@ import multiprocessing +import os from typing import Iterator, Generator, Any, List from typing import Mapping from typing import Union, Dict @@ -173,25 +174,22 @@ def some_data() -> Generator[List[DictStrAny], Any, None]: def test_pipeline_replace() -> None: - generator_instance1 = sequence_generator() - generator_instance2 = sequence_generator() + os.environ["DATA_WRITER__BUFFER_MAX_ITEMS"] = "2" + os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "2" + + generator_instance1, generator_instance2 = (sequence_generator(), sequence_generator()) @dlt.resource def some_data() -> Generator[DictStrStr, Any, None]: yield from next(generator_instance1) - lancedb_adapter( - some_data, - embed=["content"], - ) - uid = uniq_id() pipeline = dlt.pipeline( pipeline_name="test_pipeline_replace", destination="lancedb", dataset_name="test_pipeline_replace_dataset" - + uid, # lancedb doesn't mandate any name normalization + + uid, # Lancedb doesn't mandate any name normalization. 
) info = pipeline.run( diff --git a/tests/load/utils.py b/tests/load/utils.py index 5427904d52..008ac9ee95 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -938,7 +938,7 @@ def prepare_load_package( def sequence_generator() -> Generator[List[Dict[str, str]], None, None]: count = 1 while True: - yield [{"content": str(count + i)} for i in range(3)] + yield [{"content": str(count + i)} for i in range(1000)] count += 3 From f6d243a1a477cd7914b8cb47a01138532927dc90 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 5 Sep 2024 16:01:06 +0200 Subject: [PATCH 106/113] Fix replace race conditions by delegating truncation to dlt Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 4 +--- tests/load/lancedb/utils.py | 2 +- tests/load/utils.py | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 11249d0f97..c23161e427 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -201,10 +201,8 @@ def write_records( ) from e try: - if write_disposition in ("append", "skip"): + if write_disposition in ("append", "skip", "replace"): tbl.add(records) - elif write_disposition == "replace": - tbl.add(records, mode="overwrite") elif write_disposition == "merge": if remove_orphans: tbl.merge_insert(merge_key).when_not_matched_by_source_delete( diff --git a/tests/load/lancedb/utils.py b/tests/load/lancedb/utils.py index 8e2fddfba5..30430fe076 100644 --- a/tests/load/lancedb/utils.py +++ b/tests/load/lancedb/utils.py @@ -40,7 +40,7 @@ def assert_table( exists = client.table_exists(qualified_table_name) assert exists - records = client.db_client.open_table(qualified_table_name).search().limit(50).to_list() + records = client.db_client.open_table(qualified_table_name).search().limit(0).to_list() if expected_items_count is not None: assert expected_items_count == len(records) diff --git a/tests/load/utils.py b/tests/load/utils.py index 008ac9ee95..42c2ddd722 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -938,7 +938,7 @@ def prepare_load_package( def sequence_generator() -> Generator[List[Dict[str, str]], None, None]: count = 1 while True: - yield [{"content": str(count + i)} for i in range(1000)] + yield [{"content": str(count + i)} for i in range(2000)] count += 3 From f32d4cde9b7e5521596fc98ba71a477116039393 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 8 Sep 2024 17:39:33 +0200 Subject: [PATCH 107/113] Update lock file Signed-off-by: Marcel Coetzee --- poetry.lock | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 0ce139d08f..2f73de3494 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3876,6 +3876,17 @@ files = [ [package.extras] test = ["pytest", "sphinx", "sphinx-autobuild", "twine", "wheel"] +[[package]] +name = "graphlib-backport" +version = "1.1.0" +description = "Backport of the Python 3.9 graphlib module for Python 3.6+" +optional = false +python-versions = ">=3.6,<4.0" +files = [ + {file = "graphlib_backport-1.1.0-py3-none-any.whl", hash = "sha256:eccacf9f2126cdf89ce32a6018c88e1ecd3e4898a07568add6e1907a439055ba"}, + {file = "graphlib_backport-1.1.0.tar.gz", hash = "sha256:00a7888b21e5393064a133209cb5d3b3ef0a2096cf023914c9d778dff5644125"}, +] + [[package]] name = "greenlet" version = "3.0.3" @@ -5187,6 +5198,17 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = 
"sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] +[[package]] +name = "mimesis" +version = "7.1.0" +description = "Mimesis: Fake Data Generator." +optional = false +python-versions = ">=3.8,<4.0" +files = [ + {file = "mimesis-7.1.0-py3-none-any.whl", hash = "sha256:da65bea6d6d5d5d87d5c008e6b23ef5f96a49cce436d9f8708dabb5152da0290"}, + {file = "mimesis-7.1.0.tar.gz", hash = "sha256:c83b55d35536d7e9b9700a596b7ccfb639a740e3e1fb5e08062e8ab2a67dcb37"}, +] + [[package]] name = "minimal-snowplow-tracker" version = "0.0.2" @@ -9720,10 +9742,11 @@ qdrant = ["qdrant-client"] redshift = ["psycopg2-binary", "psycopg2cffi"] s3 = ["botocore", "s3fs"] snowflake = ["snowflake-connector-python"] +sql-database = ["sqlalchemy"] synapse = ["adlfs", "pyarrow", "pyodbc"] weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "1d8fa59c9ef876d699cb5b5a2fcadb9a78c4c4d28a9fca7ca0e83147c08feaae" +content-hash = "7c76d7d00be7aeacdf21defe2caeb8f2f617ce20a067b3d760ec6fd5542097fb" From 7bd2e9c4da37e55d157fffa7e8e7a32ed9384258 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 24 Sep 2024 20:43:29 +0200 Subject: [PATCH 108/113] Refactor type mapping and schema handling in LanceDB client Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 30 ++++++++++--------- dlt/destinations/impl/lancedb/schema.py | 8 ++--- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index ee944133d6..f58f0e371b 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -14,7 +14,6 @@ Set, ) -from dlt.common.destination.capabilities import DataTypeMapper import lancedb # type: ignore import lancedb.table # type: ignore import pyarrow as pa @@ -47,11 +46,11 @@ from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TSchemaTables from dlt.common.schema.typing import ( - C_DLT_LOAD_ID, TColumnType, TTableSchemaColumns, TWriteDisposition, TColumnSchema, + TTableSchema, ) from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName @@ -83,7 +82,7 @@ create_filter_condition, ) from dlt.destinations.job_impl import ReferenceFollowupJobRequest -from dlt.destinations.type_mapping import TypeMapper +from dlt.destinations.type_mapping import TypeMapperImpl if TYPE_CHECKING: NDArray = ndarray[Any, Any] @@ -96,7 +95,7 @@ EMPTY_STRING_PLACEHOLDER = "0uEoDNBpQUBwsxKbmxxB" -class LanceDBTypeMapper(TypeMapper): +class LanceDBTypeMapper(TypeMapperImpl): sct_to_unbound_dbt = { "text": pa.string(), "double": pa.float64(), @@ -738,14 +737,8 @@ def __init__( self, file_path: str, table_schema: TTableSchema, - type_mapper: DataTypeMapper, - model_func: TextEmbeddingFunction, - fq_table_name: str, ) -> None: super().__init__(file_path) - self._type_mapper = type_mapper - self._fq_table_name: str = fq_table_name - self._model_func = model_func self._job_client: "LanceDBClient" = None self._table_schema: TTableSchema = table_schema @@ -780,12 +773,15 @@ def run(self) -> None: "LanceDB's write disposition requires at least one explicit primary key." 
) + dlt_id = self._schema.naming.normalize_identifier( + self._schema.data_item_normalizer.C_DLT_ID + ) write_records( arrow_table, db_client=db_client, table_name=fq_table_name, write_disposition=write_disposition, - merge_key=self._schema.data_item_normalizer.C_DLT_ID, # type: ignore[attr-defined] + merge_key=dlt_id, ) @@ -801,9 +797,15 @@ def __init__( self.references = ReferenceFollowupJobRequest.resolve_references(file_path) def run(self) -> None: - dlt_load_id = self._schema.data_item_normalizer.C_DLT_LOAD_ID # type: ignore[attr-defined] - dlt_id = self._schema.data_item_normalizer.C_DLT_ID # type: ignore[attr-defined] - dlt_root_id = self._schema.data_item_normalizer.C_DLT_ROOT_ID # type: ignore[attr-defined] + dlt_load_id = self._schema.naming.normalize_identifier( + self._schema.data_item_normalizer.C_DLT_LOAD_ID + ) + dlt_id = self._schema.naming.normalize_identifier( + self._schema.data_item_normalizer.C_DLT_ID + ) + dlt_root_id = self._schema.naming.normalize_identifier( + self._schema.data_item_normalizer.C_DLT_ROOT_ID + ) db_client: DBConnection = self._job_client.db_client table_lineage: TTableLineage = [ diff --git a/dlt/destinations/impl/lancedb/schema.py b/dlt/destinations/impl/lancedb/schema.py index 1d7c62e420..25dfbc840a 100644 --- a/dlt/destinations/impl/lancedb/schema.py +++ b/dlt/destinations/impl/lancedb/schema.py @@ -4,19 +4,17 @@ List, cast, Optional, - Tuple, ) import pyarrow as pa from lancedb.embeddings import TextEmbeddingFunction # type: ignore from typing_extensions import TypeAlias +from dlt.common.destination.capabilities import DataTypeMapper from dlt.common.json import json -from dlt.common.schema import Schema, TColumnSchema, TTableSchema +from dlt.common.schema import Schema, TColumnSchema from dlt.common.typing import DictStrAny -from dlt.common.destination.capabilities import DataTypeMapper - TArrowSchema: TypeAlias = pa.Schema TArrowDataType: TypeAlias = pa.DataType @@ -44,9 +42,7 @@ def make_arrow_field_schema( def make_arrow_table_schema( table_name: str, schema: Schema, - type_mapper: TypeMapper, type_mapper: DataTypeMapper, - id_field_name: Optional[str] = None, vector_field_name: Optional[str] = None, embedding_fields: Optional[List[str]] = None, embedding_model_func: Optional[TextEmbeddingFunction] = None, From d8a6b755eda9afbd63fc17e5a9e80c63a40cbb74 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 24 Sep 2024 21:14:39 +0200 Subject: [PATCH 109/113] Change 'complex' column type to 'json' in LanceDB client Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index f58f0e371b..86ef36c045 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -103,7 +103,7 @@ class LanceDBTypeMapper(TypeMapperImpl): "bigint": pa.int64(), "binary": pa.binary(), "date": pa.date32(), - "complex": pa.string(), + "json": pa.string(), } sct_to_dbt = {} @@ -164,7 +164,7 @@ def from_db_type( if (precision, scale) == self.capabilities.wei_precision: return cast(TColumnType, dict(data_type="wei")) return dict(data_type="decimal", precision=precision, scale=scale) - return super().from_db_type(cast(str, db_type), precision, scale) + return super().from_db_type(cast(str, db_type), precision, scale) # type: ignore def write_records( @@ -557,7 +557,7 @@ def get_stored_state(self, pipeline_name: str) 
-> Optional[StateInfo]: # normalize property names p_load_id = self.schema.naming.normalize_identifier("load_id") p_dlt_load_id = self.schema.naming.normalize_identifier( - self.schema.data_item_normalizer.C_DLT_LOAD_ID # type: ignore[attr-defined] + self.schema.data_item_normalizer.c_dlt_load_id # type: ignore[attr-defined] ) p_pipeline_name = self.schema.naming.normalize_identifier("pipeline_name") p_status = self.schema.naming.normalize_identifier("status") @@ -707,7 +707,7 @@ def create_table_chain_completed_followup_jobs( completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJobRequest]: jobs = super().create_table_chain_completed_followup_jobs( - table_chain, completed_table_chain_jobs + table_chain, completed_table_chain_jobs # type: ignore[arg-type] ) # Orphan removal is only supported for upsert strategy because we need a deterministic key hash. first_table_in_chain = table_chain[0] @@ -774,7 +774,7 @@ def run(self) -> None: ) dlt_id = self._schema.naming.normalize_identifier( - self._schema.data_item_normalizer.C_DLT_ID + self._schema.data_item_normalizer.c_dlt_id # type: ignore[attr-defined] ) write_records( arrow_table, @@ -798,13 +798,13 @@ def __init__( def run(self) -> None: dlt_load_id = self._schema.naming.normalize_identifier( - self._schema.data_item_normalizer.C_DLT_LOAD_ID + self._schema.data_item_normalizer.c_dlt_load_id # type: ignore[attr-defined] ) dlt_id = self._schema.naming.normalize_identifier( - self._schema.data_item_normalizer.C_DLT_ID + self._schema.data_item_normalizer.c_dlt_id # type: ignore[attr-defined] ) dlt_root_id = self._schema.naming.normalize_identifier( - self._schema.data_item_normalizer.C_DLT_ROOT_ID + self._schema.data_item_normalizer.c_dlt_root_id # type: ignore[attr-defined] ) db_client: DBConnection = self._job_client.db_client From a5a1657ab60b272c3a11be65bc42964e8435f7cb Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 24 Sep 2024 21:27:06 +0200 Subject: [PATCH 110/113] update lock file Signed-off-by: Marcel Coetzee --- poetry.lock | 210 +++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 157 insertions(+), 53 deletions(-) diff --git a/poetry.lock b/poetry.lock index 12c0d75d1e..34f32de996 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. 
[[package]] name = "about-time" @@ -2167,32 +2167,33 @@ typing-extensions = ">=3.10.0" [[package]] name = "databricks-sql-connector" -version = "3.3.0" +version = "2.9.6" description = "Databricks SQL Connector for Python" optional = true -python-versions = "<4.0.0,>=3.8.0" +python-versions = "<4.0.0,>=3.7.1" files = [ - {file = "databricks_sql_connector-3.3.0-py3-none-any.whl", hash = "sha256:55ee5a4a11291bf91a235ac76e41b419ddd66a9a321065a8bfaf119acbb26d6b"}, - {file = "databricks_sql_connector-3.3.0.tar.gz", hash = "sha256:19e82965da4c86574adfe9f788c17b4494d98eb8075ba4fd4306573d2edbf194"}, + {file = "databricks_sql_connector-2.9.6-py3-none-any.whl", hash = "sha256:d830abf86e71d2eb83c6a7b7264d6c03926a8a83cec58541ddd6b83d693bde8f"}, + {file = "databricks_sql_connector-2.9.6.tar.gz", hash = "sha256:e55f5b8ede8ae6c6f31416a4cf6352f0ac019bf6875896c668c7574ceaf6e813"}, ] [package.dependencies] +alembic = ">=1.0.11,<2.0.0" lz4 = ">=4.0.2,<5.0.0" numpy = [ - {version = ">=1.16.6,<2.0.0", markers = "python_version >= \"3.8\" and python_version < \"3.11\""}, - {version = ">=1.23.4,<2.0.0", markers = "python_version >= \"3.11\""}, + {version = ">=1.16.6", markers = "python_version >= \"3.7\" and python_version < \"3.11\""}, + {version = ">=1.23.4", markers = "python_version >= \"3.11\""}, ] oauthlib = ">=3.1.0,<4.0.0" openpyxl = ">=3.0.10,<4.0.0" -pandas = {version = ">=1.2.5,<2.2.0", markers = "python_version >= \"3.8\""} -pyarrow = ">=14.0.1,<17" +pandas = {version = ">=1.2.5,<3.0.0", markers = "python_version >= \"3.8\""} +pyarrow = [ + {version = ">=6.0.0", markers = "python_version >= \"3.7\" and python_version < \"3.11\""}, + {version = ">=10.0.1", markers = "python_version >= \"3.11\""}, +] requests = ">=2.18.1,<3.0.0" -thrift = ">=0.16.0,<0.21.0" -urllib3 = ">=1.26" - -[package.extras] -alembic = ["alembic (>=1.0.11,<2.0.0)", "sqlalchemy (>=2.0.21)"] -sqlalchemy = ["sqlalchemy (>=2.0.21)"] +sqlalchemy = ">=1.3.24,<2.0.0" +thrift = ">=0.16.0,<0.17.0" +urllib3 = ">=1.0" [[package]] name = "dbt-athena-community" @@ -3788,6 +3789,106 @@ files = [ {file = "google_re2-1.1-4-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f4d4f0823e8b2f6952a145295b1ff25245ce9bb136aff6fe86452e507d4c1dd"}, {file = "google_re2-1.1-4-cp39-cp39-win32.whl", hash = "sha256:1afae56b2a07bb48cfcfefaa15ed85bae26a68f5dc7f9e128e6e6ea36914e847"}, {file = "google_re2-1.1-4-cp39-cp39-win_amd64.whl", hash = "sha256:aa7d6d05911ab9c8adbf3c225a7a120ab50fd2784ac48f2f0d140c0b7afc2b55"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:222fc2ee0e40522de0b21ad3bc90ab8983be3bf3cec3d349c80d76c8bb1a4beb"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:d4763b0b9195b72132a4e7de8e5a9bf1f05542f442a9115aa27cfc2a8004f581"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:209649da10c9d4a93d8a4d100ecbf9cc3b0252169426bec3e8b4ad7e57d600cf"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:68813aa333c1604a2df4a495b2a6ed065d7c8aebf26cc7e7abb5a6835d08353c"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:370a23ec775ad14e9d1e71474d56f381224dcf3e72b15d8ca7b4ad7dd9cd5853"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:14664a66a3ddf6bc9e56f401bf029db2d169982c53eff3f5876399104df0e9a6"}, + {file = "google_re2-1.1-5-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:3ea3722cc4932cbcebd553b69dce1b4a73572823cff4e6a244f1c855da21d511"}, + {file = "google_re2-1.1-5-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e14bb264c40fd7c627ef5678e295370cd6ba95ca71d835798b6e37502fc4c690"}, + {file = "google_re2-1.1-5-cp310-cp310-win32.whl", hash = "sha256:39512cd0151ea4b3969c992579c79b423018b464624ae955be685fc07d94556c"}, + {file = "google_re2-1.1-5-cp310-cp310-win_amd64.whl", hash = "sha256:ac66537aa3bc5504320d922b73156909e3c2b6da19739c866502f7827b3f9fdf"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5b5ea68d54890c9edb1b930dcb2658819354e5d3f2201f811798bbc0a142c2b4"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:33443511b6b83c35242370908efe2e8e1e7cae749c766b2b247bf30e8616066c"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:413d77bdd5ba0bfcada428b4c146e87707452ec50a4091ec8e8ba1413d7e0619"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:5171686e43304996a34baa2abcee6f28b169806d0e583c16d55e5656b092a414"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3b284db130283771558e31a02d8eb8fb756156ab98ce80035ae2e9e3a5f307c4"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:296e6aed0b169648dc4b870ff47bd34c702a32600adb9926154569ef51033f47"}, + {file = "google_re2-1.1-5-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:38d50e68ead374160b1e656bbb5d101f0b95fb4cc57f4a5c12100155001480c5"}, + {file = "google_re2-1.1-5-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2a0416a35921e5041758948bcb882456916f22845f66a93bc25070ef7262b72a"}, + {file = "google_re2-1.1-5-cp311-cp311-win32.whl", hash = "sha256:a1d59568bbb5de5dd56dd6cdc79907db26cce63eb4429260300c65f43469e3e7"}, + {file = "google_re2-1.1-5-cp311-cp311-win_amd64.whl", hash = "sha256:72f5a2f179648b8358737b2b493549370debd7d389884a54d331619b285514e3"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:cbc72c45937b1dc5acac3560eb1720007dccca7c9879138ff874c7f6baf96005"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:5fadd1417fbef7235fa9453dba4eb102e6e7d94b1e4c99d5fa3dd4e288d0d2ae"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:040f85c63cc02696485b59b187a5ef044abe2f99b92b4fb399de40b7d2904ccc"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:64e3b975ee6d9bbb2420494e41f929c1a0de4bcc16d86619ab7a87f6ea80d6bd"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:8ee370413e00f4d828eaed0e83b8af84d7a72e8ee4f4bd5d3078bc741dfc430a"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:5b89383001079323f693ba592d7aad789d7a02e75adb5d3368d92b300f5963fd"}, + {file = "google_re2-1.1-5-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:63cb4fdfbbda16ae31b41a6388ea621510db82feb8217a74bf36552ecfcd50ad"}, + {file = "google_re2-1.1-5-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ebedd84ae8be10b7a71a16162376fd67a2386fe6361ef88c622dcf7fd679daf"}, + {file = "google_re2-1.1-5-cp312-cp312-win32.whl", hash = "sha256:c8e22d1692bc2c81173330c721aff53e47ffd3c4403ff0cd9d91adfd255dd150"}, + {file = "google_re2-1.1-5-cp312-cp312-win_amd64.whl", hash = "sha256:5197a6af438bb8c4abda0bbe9c4fbd6c27c159855b211098b29d51b73e4cbcf6"}, + {file = 
"google_re2-1.1-5-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b6727e0b98417e114b92688ad2aa256102ece51f29b743db3d831df53faf1ce3"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:711e2b6417eb579c61a4951029d844f6b95b9b373b213232efd413659889a363"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:71ae8b3df22c5c154c8af0f0e99d234a450ef1644393bc2d7f53fc8c0a1e111c"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:94a04e214bc521a3807c217d50cf099bbdd0c0a80d2d996c0741dbb995b5f49f"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_14_0_arm64.whl", hash = "sha256:a770f75358508a9110c81a1257721f70c15d9bb592a2fb5c25ecbd13566e52a5"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_14_0_x86_64.whl", hash = "sha256:07c9133357f7e0b17c6694d5dcb82e0371f695d7c25faef2ff8117ef375343ff"}, + {file = "google_re2-1.1-5-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:204ca6b1cf2021548f4a9c29ac015e0a4ab0a7b6582bf2183d838132b60c8fda"}, + {file = "google_re2-1.1-5-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f0b95857c2c654f419ca684ec38c9c3325c24e6ba7d11910a5110775a557bb18"}, + {file = "google_re2-1.1-5-cp38-cp38-win32.whl", hash = "sha256:347ac770e091a0364e822220f8d26ab53e6fdcdeaec635052000845c5a3fb869"}, + {file = "google_re2-1.1-5-cp38-cp38-win_amd64.whl", hash = "sha256:ec32bb6de7ffb112a07d210cf9f797b7600645c2d5910703fa07f456dd2150e0"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:eb5adf89060f81c5ff26c28e261e6b4997530a923a6093c9726b8dec02a9a326"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:a22630c9dd9ceb41ca4316bccba2643a8b1d5c198f21c00ed5b50a94313aaf10"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:544dc17fcc2d43ec05f317366375796351dec44058e1164e03c3f7d050284d58"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:19710af5ea88751c7768575b23765ce0dfef7324d2539de576f75cdc319d6654"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:f82995a205e08ad896f4bd5ce4847c834fab877e1772a44e5f262a647d8a1dec"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:63533c4d58da9dc4bc040250f1f52b089911699f0368e0e6e15f996387a984ed"}, + {file = "google_re2-1.1-5-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:79e00fcf0cb04ea35a22b9014712d448725ce4ddc9f08cc818322566176ca4b0"}, + {file = "google_re2-1.1-5-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bc41afcefee2da6c4ed883a93d7f527c4b960cd1d26bbb0020a7b8c2d341a60a"}, + {file = "google_re2-1.1-5-cp39-cp39-win32.whl", hash = "sha256:486730b5e1f1c31b0abc6d80abe174ce4f1188fe17d1b50698f2bf79dc6e44be"}, + {file = "google_re2-1.1-5-cp39-cp39-win_amd64.whl", hash = "sha256:4de637ca328f1d23209e80967d1b987d6b352cd01b3a52a84b4d742c69c3da6c"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:621e9c199d1ff0fdb2a068ad450111a84b3bf14f96dfe5a8a7a0deae5f3f4cce"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:220acd31e7dde95373f97c3d1f3b3bd2532b38936af28b1917ee265d25bebbf4"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:db34e1098d164f76251a6ece30e8f0ddfd65bb658619f48613ce71acb3f9cbdb"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:5152bac41d8073977582f06257219541d0fc46ad99b0bbf30e8f60198a43b08c"}, 
+ {file = "google_re2-1.1-6-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:6191294799e373ee1735af91f55abd23b786bdfd270768a690d9d55af9ea1b0d"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:070cbafbb4fecbb02e98feb28a1eb292fb880f434d531f38cc33ee314b521f1f"}, + {file = "google_re2-1.1-6-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8437d078b405a59a576cbed544490fe041140f64411f2d91012e8ec05ab8bf86"}, + {file = "google_re2-1.1-6-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f00f9a9af8896040e37896d9b9fc409ad4979f1ddd85bb188694a7d95ddd1164"}, + {file = "google_re2-1.1-6-cp310-cp310-win32.whl", hash = "sha256:df26345f229a898b4fd3cafd5f82259869388cee6268fc35af16a8e2293dd4e5"}, + {file = "google_re2-1.1-6-cp310-cp310-win_amd64.whl", hash = "sha256:3665d08262c57c9b28a5bdeb88632ad792c4e5f417e5645901695ab2624f5059"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b26b869d8aa1d8fe67c42836bf3416bb72f444528ee2431cfb59c0d3e02c6ce3"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:41fd4486c57dea4f222a6bb7f1ff79accf76676a73bdb8da0fcbd5ba73f8da71"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:0ee378e2e74e25960070c338c28192377c4dd41e7f4608f2688064bd2badc41e"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:a00cdbf662693367b36d075b29feb649fd7ee1b617cf84f85f2deebeda25fc64"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:4c09455014217a41499432b8c8f792f25f3df0ea2982203c3a8c8ca0e7895e69"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:6501717909185327935c7945e23bb5aa8fc7b6f237b45fe3647fa36148662158"}, + {file = "google_re2-1.1-6-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3510b04790355f199e7861c29234081900e1e1cbf2d1484da48aa0ba6d7356ab"}, + {file = "google_re2-1.1-6-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8c0e64c187ca406764f9e9ad6e750d62e69ed8f75bf2e865d0bfbc03b642361c"}, + {file = "google_re2-1.1-6-cp311-cp311-win32.whl", hash = "sha256:2a199132350542b0de0f31acbb3ca87c3a90895d1d6e5235f7792bb0af02e523"}, + {file = "google_re2-1.1-6-cp311-cp311-win_amd64.whl", hash = "sha256:83bdac8ceaece8a6db082ea3a8ba6a99a2a1ee7e9f01a9d6d50f79c6f251a01d"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:81985ff894cd45ab5a73025922ac28c0707759db8171dd2f2cc7a0e856b6b5ad"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:5635af26065e6b45456ccbea08674ae2ab62494008d9202df628df3b267bc095"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:813b6f04de79f4a8fdfe05e2cb33e0ccb40fe75d30ba441d519168f9d958bd54"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:5ec2f5332ad4fd232c3f2d6748c2c7845ccb66156a87df73abcc07f895d62ead"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:5a687b3b32a6cbb731647393b7c4e3fde244aa557f647df124ff83fb9b93e170"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:39a62f9b3db5d3021a09a47f5b91708b64a0580193e5352751eb0c689e4ad3d7"}, + {file = "google_re2-1.1-6-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ca0f0b45d4a1709cbf5d21f355e5809ac238f1ee594625a1e5ffa9ff7a09eb2b"}, + {file = 
"google_re2-1.1-6-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a64b3796a7a616c7861247bd061c9a836b5caf0d5963e5ea8022125601cf7b09"}, + {file = "google_re2-1.1-6-cp312-cp312-win32.whl", hash = "sha256:32783b9cb88469ba4cd9472d459fe4865280a6b1acdad4480a7b5081144c4eb7"}, + {file = "google_re2-1.1-6-cp312-cp312-win_amd64.whl", hash = "sha256:259ff3fd2d39035b9cbcbf375995f83fa5d9e6a0c5b94406ff1cc168ed41d6c6"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:e4711bcffe190acd29104d8ecfea0c0e42b754837de3fb8aad96e6cc3c613cdc"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:4d081cce43f39c2e813fe5990e1e378cbdb579d3f66ded5bade96130269ffd75"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:4f123b54d48450d2d6b14d8fad38e930fb65b5b84f1b022c10f2913bd956f5b5"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:e1928b304a2b591a28eb3175f9db7f17c40c12cf2d4ec2a85fdf1cc9c073ff91"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_14_0_arm64.whl", hash = "sha256:3a69f76146166aec1173003c1f547931bdf288c6b135fda0020468492ac4149f"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_14_0_x86_64.whl", hash = "sha256:fc08c388f4ebbbca345e84a0c56362180d33d11cbe9ccfae663e4db88e13751e"}, + {file = "google_re2-1.1-6-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b057adf38ce4e616486922f2f47fc7d19c827ba0a7f69d540a3664eba2269325"}, + {file = "google_re2-1.1-6-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4138c0b933ab099e96f5d8defce4486f7dfd480ecaf7f221f2409f28022ccbc5"}, + {file = "google_re2-1.1-6-cp38-cp38-win32.whl", hash = "sha256:9693e45b37b504634b1abbf1ee979471ac6a70a0035954592af616306ab05dd6"}, + {file = "google_re2-1.1-6-cp38-cp38-win_amd64.whl", hash = "sha256:5674d437baba0ea287a5a7f8f81f24265d6ae8f8c09384e2ef7b6f84b40a7826"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:7783137cb2e04f458a530c6d0ee9ef114815c1d48b9102f023998c371a3b060e"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:a49b7153935e7a303675f4deb5f5d02ab1305adefc436071348706d147c889e0"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:a96a8bb309182090704593c60bdb369a2756b38fe358bbf0d40ddeb99c71769f"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:dff3d4be9f27ef8ec3705eed54f19ef4ab096f5876c15fe011628c69ba3b561c"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:40f818b0b39e26811fa677978112a8108269977fdab2ba0453ac4363c35d9e66"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:8a7e53538cdb40ef4296017acfbb05cab0c19998be7552db1cfb85ba40b171b9"}, + {file = "google_re2-1.1-6-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6ee18e7569fb714e5bb8c42809bf8160738637a5e71ed5a4797757a1fb4dc4de"}, + {file = "google_re2-1.1-6-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1cda4f6d1a7d5b43ea92bc395f23853fba0caf8b1e1efa6e8c48685f912fcb89"}, + {file = "google_re2-1.1-6-cp39-cp39-win32.whl", hash = "sha256:6a9cdbdc36a2bf24f897be6a6c85125876dc26fea9eb4247234aec0decbdccfd"}, + {file = "google_re2-1.1-6-cp39-cp39-win_amd64.whl", hash = "sha256:73f646cecfad7cc5b4330b4192c25f2e29730a3b8408e089ffd2078094208196"}, ] [[package]] @@ -6737,52 +6838,55 @@ files = [ [[package]] name = "pyarrow" -version = "16.1.0" +version = "17.0.0" description = "Python 
library for Apache Arrow" optional = false python-versions = ">=3.8" files = [ - {file = "pyarrow-16.1.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:17e23b9a65a70cc733d8b738baa6ad3722298fa0c81d88f63ff94bf25eaa77b9"}, - {file = "pyarrow-16.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4740cc41e2ba5d641071d0ab5e9ef9b5e6e8c7611351a5cb7c1d175eaf43674a"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:98100e0268d04e0eec47b73f20b39c45b4006f3c4233719c3848aa27a03c1aef"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f68f409e7b283c085f2da014f9ef81e885d90dcd733bd648cfba3ef265961848"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a8914cd176f448e09746037b0c6b3a9d7688cef451ec5735094055116857580c"}, - {file = "pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:48be160782c0556156d91adbdd5a4a7e719f8d407cb46ae3bb4eaee09b3111bd"}, - {file = "pyarrow-16.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:9cf389d444b0f41d9fe1444b70650fea31e9d52cfcb5f818b7888b91b586efff"}, - {file = "pyarrow-16.1.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:d0ebea336b535b37eee9eee31761813086d33ed06de9ab6fc6aaa0bace7b250c"}, - {file = "pyarrow-16.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e73cfc4a99e796727919c5541c65bb88b973377501e39b9842ea71401ca6c1c"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf9251264247ecfe93e5f5a0cd43b8ae834f1e61d1abca22da55b20c788417f6"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddf5aace92d520d3d2a20031d8b0ec27b4395cab9f74e07cc95edf42a5cc0147"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:25233642583bf658f629eb230b9bb79d9af4d9f9229890b3c878699c82f7d11e"}, - {file = "pyarrow-16.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a33a64576fddfbec0a44112eaf844c20853647ca833e9a647bfae0582b2ff94b"}, - {file = "pyarrow-16.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:185d121b50836379fe012753cf15c4ba9638bda9645183ab36246923875f8d1b"}, - {file = "pyarrow-16.1.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2e51ca1d6ed7f2e9d5c3c83decf27b0d17bb207a7dea986e8dc3e24f80ff7d6f"}, - {file = "pyarrow-16.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06ebccb6f8cb7357de85f60d5da50e83507954af617d7b05f48af1621d331c9a"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b04707f1979815f5e49824ce52d1dceb46e2f12909a48a6a753fe7cafbc44a0c"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d32000693deff8dc5df444b032b5985a48592c0697cb6e3071a5d59888714e2"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8785bb10d5d6fd5e15d718ee1d1f914fe768bf8b4d1e5e9bf253de8a26cb1628"}, - {file = "pyarrow-16.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e1369af39587b794873b8a307cc6623a3b1194e69399af0efd05bb202195a5a7"}, - {file = "pyarrow-16.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:febde33305f1498f6df85e8020bca496d0e9ebf2093bab9e0f65e2b4ae2b3444"}, - {file = "pyarrow-16.1.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b5f5705ab977947a43ac83b52ade3b881eb6e95fcc02d76f501d549a210ba77f"}, - {file = "pyarrow-16.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = 
"sha256:0d27bf89dfc2576f6206e9cd6cf7a107c9c06dc13d53bbc25b0bd4556f19cf5f"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d07de3ee730647a600037bc1d7b7994067ed64d0eba797ac74b2bc77384f4c2"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fbef391b63f708e103df99fbaa3acf9f671d77a183a07546ba2f2c297b361e83"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19741c4dbbbc986d38856ee7ddfdd6a00fc3b0fc2d928795b95410d38bb97d15"}, - {file = "pyarrow-16.1.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:f2c5fb249caa17b94e2b9278b36a05ce03d3180e6da0c4c3b3ce5b2788f30eed"}, - {file = "pyarrow-16.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:e6b6d3cd35fbb93b70ade1336022cc1147b95ec6af7d36906ca7fe432eb09710"}, - {file = "pyarrow-16.1.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:18da9b76a36a954665ccca8aa6bd9f46c1145f79c0bb8f4f244f5f8e799bca55"}, - {file = "pyarrow-16.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:99f7549779b6e434467d2aa43ab2b7224dd9e41bdde486020bae198978c9e05e"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f07fdffe4fd5b15f5ec15c8b64584868d063bc22b86b46c9695624ca3505b7b4"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddfe389a08ea374972bd4065d5f25d14e36b43ebc22fc75f7b951f24378bf0b5"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b20bd67c94b3a2ea0a749d2a5712fc845a69cb5d52e78e6449bbd295611f3aa"}, - {file = "pyarrow-16.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:ba8ac20693c0bb0bf4b238751d4409e62852004a8cf031c73b0e0962b03e45e3"}, - {file = "pyarrow-16.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:31a1851751433d89a986616015841977e0a188662fcffd1a5677453f1df2de0a"}, - {file = "pyarrow-16.1.0.tar.gz", hash = "sha256:15fbb22ea96d11f0b5768504a3f961edab25eaf4197c341720c4a387f6c60315"}, + {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"}, + {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da1e060b3876faa11cee287839f9cc7cdc00649f475714b8680a05fd9071d545"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75c06d4624c0ad6674364bb46ef38c3132768139ddec1c56582dbac54f2663e2"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:fa3c246cc58cb5a4a5cb407a18f193354ea47dd0648194e6265bd24177982fe8"}, + {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:f7ae2de664e0b158d1607699a16a488de3d008ba99b3a7aa5de1cbc13574d047"}, + {file = "pyarrow-17.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5984f416552eea15fd9cee03da53542bf4cddaef5afecefb9aa8d1010c335087"}, + {file = "pyarrow-17.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:1c8856e2ef09eb87ecf937104aacfa0708f22dfeb039c363ec99735190ffb977"}, + {file = "pyarrow-17.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e19f569567efcbbd42084e87f948778eb371d308e137a0f97afe19bb860ccb3"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b244dc8e08a23b3e352899a006a26ae7b4d0da7bb636872fa8f5884e70acf15"}, + 
{file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b72e87fe3e1db343995562f7fff8aee354b55ee83d13afba65400c178ab2597"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dc5c31c37409dfbc5d014047817cb4ccd8c1ea25d19576acf1a001fe07f5b420"}, + {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e3343cb1e88bc2ea605986d4b94948716edc7a8d14afd4e2c097232f729758b4"}, + {file = "pyarrow-17.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:a27532c38f3de9eb3e90ecab63dfda948a8ca859a66e3a47f5f42d1e403c4d03"}, + {file = "pyarrow-17.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9b8a823cea605221e61f34859dcc03207e52e409ccf6354634143e23af7c8d22"}, + {file = "pyarrow-17.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f1e70de6cb5790a50b01d2b686d54aaf73da01266850b05e3af2a1bc89e16053"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0071ce35788c6f9077ff9ecba4858108eebe2ea5a3f7cf2cf55ebc1dbc6ee24a"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:757074882f844411fcca735e39aae74248a1531367a7c80799b4266390ae51cc"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ba11c4f16976e89146781a83833df7f82077cdab7dc6232c897789343f7891a"}, + {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b0c6ac301093b42d34410b187bba560b17c0330f64907bfa4f7f7f2444b0cf9b"}, + {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"}, + {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"}, + {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"}, + {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"}, + {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"}, + {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"}, + {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"}, + {file = 
"pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"}, + {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"}, + {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"}, ] [package.dependencies] numpy = ">=1.16.6" +[package.extras] +test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] + [[package]] name = "pyasn1" version = "0.5.0" @@ -9829,4 +9933,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "985bb75a9579b44a5f9fd029ade1cc77455b544f2e18f9741b1d0d89bd188537" +content-hash = "6101cae0864d80307ae6d5f33ea263ce8e6d9f86e6e06d317c3d301818aa442e" From 0fc0473662c18d3628a6edba7d9976fe51af7fe2 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 5 Nov 2024 21:46:03 +0100 Subject: [PATCH 111/113] fixes generating lancedb literals --- dlt/common/data_writers/escape.py | 17 +++++++++++++++++ dlt/destinations/impl/lancedb/utils.py | 21 ++++++--------------- tests/load/lancedb/test_utils.py | 16 +++++++++++++++- 3 files changed, 38 insertions(+), 16 deletions(-) diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py index 06c8d7a95a..393e9e8508 100644 --- a/dlt/common/data_writers/escape.py +++ b/dlt/common/data_writers/escape.py @@ -79,6 +79,23 @@ def escape_duckdb_literal(v: Any) -> Any: return str(v) +def escape_lancedb_literal(v: Any) -> Any: + if isinstance(v, str): + # we escape extended string which behave like the redshift string + return _escape_extended(v, prefix="'") + if isinstance(v, (datetime, date, time)): + return f"'{v.isoformat()}'" + if isinstance(v, (list, dict)): + return _escape_extended(json.dumps(v), prefix="'") + # TODO: check how binaries are represented in fusion + if isinstance(v, bytes): + return f"from_base64('{base64.b64encode(v).decode('ascii')}')" + if v is None: + return "NULL" + + return str(v) + + MS_SQL_ESCAPE_DICT = { "'": "''", "\n": "' + CHAR(10) + N'", diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index f07f2754d2..56991b090f 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -4,9 +4,10 @@ import pyarrow as pa from dlt.common import logger +from dlt.common.data_writers.escape import escape_lancedb_literal from dlt.common.destination.exceptions import DestinationTerminalException from dlt.common.schema import TTableSchema -from dlt.common.schema.utils import get_columns_names_with_prop +from dlt.common.schema.utils import get_columns_names_with_prop, get_first_column_name_with_prop from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider EMPTY_STRING_PLACEHOLDER = "0uEoDNBpQUBwsxKbmxxB" @@ -28,14 +29,8 @@ def set_non_standard_providers_environment_variables( def get_canonical_vector_database_doc_id_merge_key( load_table: TTableSchema, ) -> str: - if merge_key := get_columns_names_with_prop(load_table, "merge_key"): - if len(merge_key) > 1: - raise DestinationTerminalException( - "You cannot specify multiple merge keys with LanceDB orphan remove enabled:" - f" {merge_key}" - ) - else: - return merge_key[0] + if merge_key := get_first_column_name_with_prop(load_table, "merge_key"): + return merge_key elif primary_key := get_columns_names_with_prop(load_table, "primary_key"): # No merge key defined, warn and assume the first element 
of the primary key is `doc_id`. logger.warning( @@ -74,9 +69,5 @@ def fill_empty_source_column_values_with_placeholder( def create_filter_condition(field_name: str, array: pa.Array) -> str: - def format_value(element: Union[str, int, float, pa.Scalar]) -> str: - if isinstance(element, pa.Scalar): - element = element.as_py() - return "'" + element.replace("'", "''") + "'" if isinstance(element, str) else str(element) - - return f"{field_name} IN ({', '.join(map(format_value, array))})" + array_py = array.to_pylist() + return f"{field_name} IN ({', '.join(map(escape_lancedb_literal, array_py))})" diff --git a/tests/load/lancedb/test_utils.py b/tests/load/lancedb/test_utils.py index 2f517aac8e..d7f9729f26 100644 --- a/tests/load/lancedb/test_utils.py +++ b/tests/load/lancedb/test_utils.py @@ -1,7 +1,10 @@ import pyarrow as pa import pytest -from dlt.destinations.impl.lancedb.utils import fill_empty_source_column_values_with_placeholder +from dlt.destinations.impl.lancedb.utils import ( + create_filter_condition, + fill_empty_source_column_values_with_placeholder, +) # Mark all tests as essential, don't remove. @@ -30,3 +33,14 @@ def test_fill_empty_source_column_values_with_placeholder() -> None: ] expected_table = pa.Table.from_arrays(expected_data, names=["A", "B", "C", "D"]) assert new_table.equals(expected_table) + + +def test_create_filter_condition() -> None: + assert ( + create_filter_condition("_dlt_load_id", pa.array(["A", "B", "C'c\n"])) + == "_dlt_load_id IN ('A', 'B', 'C''c\\n')" + ) + assert ( + create_filter_condition("_dlt_load_id", pa.array([1.2, 3, 5 / 2])) + == "_dlt_load_id IN (1.2, 3.0, 2.5)" + ) From 5efd0a8ca226c7c959e067735af0b0bfae65eb9e Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 5 Nov 2024 21:47:55 +0100 Subject: [PATCH 112/113] verifies merge key early, fixes column override in adapters --- .../impl/lancedb/lancedb_adapter.py | 23 +++-------- .../impl/lancedb/lancedb_client.py | 40 +++++++++++-------- .../impl/qdrant/qdrant_adapter.py | 3 +- .../impl/weaviate/weaviate_adapter.py | 1 + tests/load/lancedb/test_pipeline.py | 8 +++- 5 files changed, 39 insertions(+), 36 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_adapter.py b/dlt/destinations/impl/lancedb/lancedb_adapter.py index 8f4fbb091d..4314dd703f 100644 --- a/dlt/destinations/impl/lancedb/lancedb_adapter.py +++ b/dlt/destinations/impl/lancedb/lancedb_adapter.py @@ -43,7 +43,7 @@ def lancedb_adapter( resource = get_resource_for_adapter(data) additional_table_hints: Dict[str, TTableHintTemplate[Any]] = {} - column_hints: TTableSchemaColumns = {} + column_hints: TTableSchemaColumns = None if embed: if isinstance(embed, str): @@ -52,6 +52,7 @@ def lancedb_adapter( raise ValueError( "'embed' must be a list of column names or a single column name as a string." ) + column_hints = {} for column_name in embed: column_hints[column_name] = { @@ -59,24 +60,12 @@ def lancedb_adapter( VECTORIZE_HINT: True, # type: ignore[misc] } - if merge_key: - if isinstance(merge_key, str): - merge_key = [merge_key] - if not isinstance(merge_key, list): - raise ValueError( - "'merge_key' must be a list of column names or a single column name as a string." 
- ) - - for column_name in merge_key: - column_hints[column_name] = { - "name": column_name, - "merge_key": True, - } - additional_table_hints[NO_REMOVE_ORPHANS_HINT] = no_remove_orphans - if column_hints or additional_table_hints: - resource.apply_hints(columns=column_hints, additional_table_hints=additional_table_hints) + if column_hints or additional_table_hints or merge_key: + resource.apply_hints( + merge_key=merge_key, columns=column_hints, additional_table_hints=additional_table_hints + ) else: raise ValueError( "You must must provide at least either the 'embed' or 'merge_key' or 'remove_orphans'" diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 86ef36c045..d5ceb1cf4c 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -52,7 +52,7 @@ TColumnSchema, TTableSchema, ) -from dlt.common.schema.utils import get_columns_names_with_prop +from dlt.common.schema.utils import get_columns_names_with_prop, is_nested_table from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName from dlt.destinations.impl.lancedb.configuration import ( LanceDBClientConfiguration, @@ -92,7 +92,6 @@ TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} BATCH_PROCESS_CHUNK_SIZE = 10_000 -EMPTY_STRING_PLACEHOLDER = "0uEoDNBpQUBwsxKbmxxB" class LanceDBTypeMapper(TypeMapperImpl): @@ -297,7 +296,7 @@ def create_table( Args: schema: The table schema to create. table_name: The name of the table to create. - mode (): The mode to use when creating the table. Can be either "create" or "overwrite". + mode (str): The mode to use when creating the table. Can be either "create" or "overwrite". By default, if the table already exists, an exception is raised. If you want to overwrite the table, use mode="overwrite". 
""" @@ -377,6 +376,21 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: def is_storage_initialized(self) -> bool: return self.table_exists(self.sentinel_table) + def verify_schema( + self, only_tables: Iterable[str] = None, new_jobs: Iterable[ParsedLoadJobFileName] = None + ) -> List[PreparedTableSchema]: + loaded_tables = super().verify_schema(only_tables, new_jobs) + # verify merge keys early + for load_table in loaded_tables: + if not is_nested_table(load_table) and not load_table.get(NO_REMOVE_ORPHANS_HINT): + if merge_key := get_columns_names_with_prop(load_table, "merge_key"): + if len(merge_key) > 1: + raise DestinationTerminalException( + "You cannot specify multiple merge keys with LanceDB orphan remove" + f" enabled: {merge_key}" + ) + return loaded_tables + def _create_sentinel_table(self) -> "lancedb.table.Table": """Create an empty table to indicate that the storage is initialized.""" return self.create_table(schema=NULL_SCHEMA, table_name=self.sentinel_table) @@ -405,7 +419,7 @@ def update_stored_schema( # TODO: return a real updated table schema (like in SQL job client) self._execute_schema_update(only_tables) else: - logger.info( + logger.debug( f"Schema with hash {self.schema.stored_version_hash} " f"inserted at {schema_info.inserted_at} found " "in storage, no upgrade required" @@ -567,7 +581,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: p_created_at = self.schema.naming.normalize_identifier("created_at") p_version_hash = self.schema.naming.normalize_identifier("version_hash") - # Read the tables into memory as Arrow tables, with pushdown predicates, so we pull as less + # Read the tables into memory as Arrow tables, with pushdown predicates, so we pull as little # data into memory as possible. state_table = ( state_table_.search() @@ -765,7 +779,7 @@ def run(self) -> None: # Hence, we require at least a primary key on the root table if the merge disposition is chosen. if ( (self._load_table not in self._schema.dlt_table_names()) - and not self._load_table.get("parent") # Is root table. + and not is_nested_table(self._load_table) # Is root table. 
and (write_disposition == "merge") and (not get_columns_names_with_prop(self._load_table, "primary_key")) ): @@ -797,15 +811,9 @@ def __init__( self.references = ReferenceFollowupJobRequest.resolve_references(file_path) def run(self) -> None: - dlt_load_id = self._schema.naming.normalize_identifier( - self._schema.data_item_normalizer.c_dlt_load_id # type: ignore[attr-defined] - ) - dlt_id = self._schema.naming.normalize_identifier( - self._schema.data_item_normalizer.c_dlt_id # type: ignore[attr-defined] - ) - dlt_root_id = self._schema.naming.normalize_identifier( - self._schema.data_item_normalizer.c_dlt_root_id # type: ignore[attr-defined] - ) + dlt_load_id = self._schema.data_item_normalizer.c_dlt_load_id # type: ignore[attr-defined] + dlt_id = self._schema.data_item_normalizer.c_dlt_id # type: ignore[attr-defined] + dlt_root_id = self._schema.data_item_normalizer.c_dlt_root_id # type: ignore[attr-defined] db_client: DBConnection = self._job_client.db_client table_lineage: TTableLineage = [ @@ -820,7 +828,7 @@ def run(self) -> None: ] for job in table_lineage: - target_is_root_table = "parent" not in job.table_schema + target_is_root_table = not is_nested_table(job.table_schema) fq_table_name = self._job_client.make_qualified_table_name(job.table_name) file_path = job.file_path with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: diff --git a/dlt/destinations/impl/qdrant/qdrant_adapter.py b/dlt/destinations/impl/qdrant/qdrant_adapter.py index abe301fff0..bbc2d719a8 100644 --- a/dlt/destinations/impl/qdrant/qdrant_adapter.py +++ b/dlt/destinations/impl/qdrant/qdrant_adapter.py @@ -34,7 +34,7 @@ def qdrant_adapter( """ resource = get_resource_for_adapter(data) - column_hints: TTableSchemaColumns = {} + column_hints: TTableSchemaColumns = None if embed: if isinstance(embed, str): @@ -44,6 +44,7 @@ def qdrant_adapter( "embed must be a list of column names or a single column name as a string" ) + column_hints = {} for column_name in embed: column_hints[column_name] = { "name": column_name, diff --git a/dlt/destinations/impl/weaviate/weaviate_adapter.py b/dlt/destinations/impl/weaviate/weaviate_adapter.py index 9bd0b41783..0ca9047528 100644 --- a/dlt/destinations/impl/weaviate/weaviate_adapter.py +++ b/dlt/destinations/impl/weaviate/weaviate_adapter.py @@ -87,6 +87,7 @@ def weaviate_adapter( TOKENIZATION_HINT: method, # type: ignore } + # this makes sure that {} as column_hints never gets into apply_hints (that would reset existing columns) if not column_hints: raise ValueError("Either 'vectorize' or 'tokenization' must be specified.") else: diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 7d320ee83c..1549554496 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -56,16 +56,20 @@ def some_data() -> Generator[DictStrStr, Any, None]: lancedb_adapter( some_data, - merge_key=["content"], + merge_key="content", ) + # via merge_key + assert some_data._hints["merge_key"] == "content" + assert some_data.columns["content"] == { # type: ignore "name": "content", "data_type": "text", "x-lancedb-embed": True, - "merge_key": True, } + assert some_data.compute_table_schema()["columns"]["content"]["merge_key"] is True + def test_basic_state_and_schema() -> None: generator_instance1 = sequence_generator() From 4dcc28226ea591cd8d92344e5e930f6c94266d30 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 6 Nov 2024 03:03:09 +0100 Subject: [PATCH 113/113] fixes linting errors --- 
dlt/destinations/impl/databricks/databricks.py | 2 +- dlt/destinations/impl/databricks/sql_client.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 75bd8ffa13..718427af87 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -224,7 +224,7 @@ def __init__( ) super().__init__(schema, config, sql_client) self.config: DatabricksClientConfiguration = config - self.sql_client: DatabricksSqlClient = sql_client # type: ignore[assignment] + self.sql_client: DatabricksSqlClient = sql_client self.type_mapper = self.capabilities.get_type_mapper() def create_load_job( diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index 88d47410d5..8bff4e0d73 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -41,7 +41,7 @@ class DatabricksCursorImpl(DBApiCursorImpl): """Use native data frame support if available""" - native_cursor: DatabricksSqlCursor # type: ignore[assignment] + native_cursor: DatabricksSqlCursor vector_size: ClassVar[int] = 2048 # vector size is 2048 def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]: @@ -144,7 +144,7 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB # db_args = kwargs or None db_args = args or kwargs or None - with self._conn.cursor() as curr: # type: ignore[assignment] + with self._conn.cursor() as curr: curr.execute(query, db_args) yield DatabricksCursorImpl(curr) # type: ignore[abstract]