From 68e26a0587e79e9e2cb61d3649086e3108911fff Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 16 Jul 2024 20:56:18 +0200 Subject: [PATCH 01/68] Add tests for LanceDB chunking and merging functionality Signed-off-by: Marcel Coetzee --- tests/load/lancedb/test_pipeline.py | 89 ++++++++++++++++++++++++++++- 1 file changed, 87 insertions(+), 2 deletions(-) diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index e817a2f6c8..3f2f1298ba 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -16,7 +16,7 @@ from tests.pipeline.utils import assert_load_info -# Mark all tests as essential, do not remove. +# Mark all tests as essential, don't remove. pytestmark = pytest.mark.essential @@ -426,10 +426,95 @@ def test_empty_dataset_allowed() -> None: client: LanceDBClient = pipe.destination_client() # type: ignore[assignment] assert pipe.dataset_name is None - info = pipe.run(lancedb_adapter(["context", "created", "not a stop word"], embed=["value"])) + info = pipe.run( + lancedb_adapter(["context", "created", "not a stop word"], embed=["value"]) + ) # Dataset in load info is empty. assert info.dataset_name is None client = pipe.destination_client() # type: ignore[assignment] assert client.dataset_name is None assert client.sentinel_table == "dltSentinelTable" assert_table(pipe, "content", expected_items_count=3) + + +docs = [ + [ + { + "text": "This is the first document. It contains some text that will be chunked and embedded. (I don't want " + "to be seen in updated run's embedding chunk texts btw)", + "id": 1, + }, + { + "text": "Here's another document. It's a bit different from the first one.", + "id": 2, + }, + ], + [ + { + "text": "This is the first document, but it has been updated with new content.", + "id": 1, + }, + { + "text": "This is a completely new document that wasn't in the initial set.", + "id": 3, + }, + ], +] + + +def splitter(text: str, chunk_size: int = 10) -> List[str]: + return [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)] + + +def test_chunking_no_splitter() -> None: + pipe = dlt.pipeline(destination="lancedb", dataset_name="docs", dev_mode=True) + info = pipe.run( + docs[0], + table_name="documents", + ) + assert_load_info(info) + + # TODO: Check and compare output + + +def test_chunking_with_splitter() -> None: + pipe = dlt.pipeline(destination="lancedb", dataset_name="docs", dev_mode=True) + + info = pipe.run( + lancedb_adapter(docs[0], embed="text", splitter=splitter), + table_name="documents", + ) + assert_load_info(info) + + # TODO: Check and compare output + + +def test_chunk_merge() -> None: + """Test chunking is applied without orphaned chunks when new documents arrive.""" + + pipe = dlt.pipeline(destination="lancedb", dataset_name="docs", dev_mode=True) + + + info = pipe.run( + lancedb_adapter(docs[0], embed="text", splitter=splitter), + table_name="documents", + write_disposition="merge", + primary_key="id", + ) + pipe.run(info) + + # Orphaned chunks must be discarded. 
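+    # The second run replaces document id 1, so chunks produced from its first version must not survive the merge.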
+ info = pipe.run( + lancedb_adapter(docs[1], embed="text", splitter=splitter), + table_name="documents", + write_disposition="merge", + primary_key="id", + ) + assert_load_info(info) + + # TODO: Check and compare output + + +def test_embedding_provider_only_called_once_per_chunk_hash() -> None: + """Verify that the embedding provider is called only once for each unique chunk hash to optimize API usage and reduce costs.""" + raise NotImplementedError From 900c4faabf346587ad558439600f464cabe33486 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 18 Jul 2024 18:32:58 +0200 Subject: [PATCH 02/68] Add TSplitter type alias for LanceDB document splitting function Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/typing.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 dlt/destinations/impl/lancedb/typing.py diff --git a/dlt/destinations/impl/lancedb/typing.py b/dlt/destinations/impl/lancedb/typing.py new file mode 100644 index 0000000000..c97b23fa7c --- /dev/null +++ b/dlt/destinations/impl/lancedb/typing.py @@ -0,0 +1,4 @@ +from collections.abc import Callable +from typing import List, Any + +TSplitter = Callable[[str, Any], List[str | Any]] From 16230a70506edac7c676cb25cbda2a3078be6e11 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 18 Jul 2024 18:57:11 +0200 Subject: [PATCH 03/68] Refine typing for chunks Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/typing.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/impl/lancedb/typing.py b/dlt/destinations/impl/lancedb/typing.py index c97b23fa7c..bf56955b79 100644 --- a/dlt/destinations/impl/lancedb/typing.py +++ b/dlt/destinations/impl/lancedb/typing.py @@ -1,4 +1,6 @@ -from collections.abc import Callable -from typing import List, Any +from typing import Callable, Union, List, Dict, Any -TSplitter = Callable[[str, Any], List[str | Any]] +ChunkInputT = Union[str, Dict[str, Any], Any] +ChunkOutputT = List[Any] + +TSplitter = Callable[[ChunkInputT, Any], ChunkOutputT] From 3f7a82f972561266d14e3e6c4e6880f1757ac5ff Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Fri, 19 Jul 2024 12:18:14 +0200 Subject: [PATCH 04/68] Add type definitions for chunk splitter function and related types Signed-off-by: Marcel Coetzee --- dlt/common/schema/typing.py | 5 +++++ dlt/destinations/impl/lancedb/typing.py | 6 ------ 2 files changed, 5 insertions(+), 6 deletions(-) delete mode 100644 dlt/destinations/impl/lancedb/typing.py diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index 9a4dd51d4b..d0ada7a778 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -75,6 +75,11 @@ TColumnNames = Union[str, Sequence[str]] """A string representing a column name or a list of""" +ChunkInputT = Union[str, Dict[str, Any], Any] +ChunkOutputT = List[Any] +TSplitter = Callable[[ChunkInputT, Any], ChunkOutputT] +"""Splitter function that takes a ChunkInputT and any additional arguments, returning a ChunkOutputT.""" + # COLUMN_PROPS: Set[TColumnProp] = set(get_args(TColumnProp)) COLUMN_HINTS: Set[TColumnHint] = set( [ diff --git a/dlt/destinations/impl/lancedb/typing.py b/dlt/destinations/impl/lancedb/typing.py deleted file mode 100644 index bf56955b79..0000000000 --- a/dlt/destinations/impl/lancedb/typing.py +++ /dev/null @@ -1,6 +0,0 @@ -from typing import Callable, Union, List, Dict, Any - -ChunkInputT = Union[str, Dict[str, Any], Any] -ChunkOutputT = List[Any] - -TSplitter = Callable[[ChunkInputT, Any], ChunkOutputT] From 
1dda1d517492ff5a25b220f69b960b3e661ef210 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Fri, 19 Jul 2024 14:03:20 +0200 Subject: [PATCH 05/68] Remove unused ChunkInputT, ChunkOutputT, and TSplitter type definitions Signed-off-by: Marcel Coetzee --- dlt/common/schema/typing.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index d0ada7a778..9a4dd51d4b 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -75,11 +75,6 @@ TColumnNames = Union[str, Sequence[str]] """A string representing a column name or a list of""" -ChunkInputT = Union[str, Dict[str, Any], Any] -ChunkOutputT = List[Any] -TSplitter = Callable[[ChunkInputT, Any], ChunkOutputT] -"""Splitter function that takes a ChunkInputT and any additional arguments, returning a ChunkOutputT.""" - # COLUMN_PROPS: Set[TColumnProp] = set(get_args(TColumnProp)) COLUMN_HINTS: Set[TColumnHint] = set( [ From 48e14ab64f2c7fb57c0ec748d2f3647811dff53e Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 21 Jul 2024 20:43:35 +0200 Subject: [PATCH 06/68] Implement efficient update strategy for chunked documents in LanceDB Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 41 +++++++++++++++---- tests/load/lancedb/test_pipeline.py | 31 +++----------- 2 files changed, 39 insertions(+), 33 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 8265e50fbf..e3025d19bf 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -1,7 +1,6 @@ import uuid from types import TracebackType from typing import ( - ClassVar, List, Any, cast, @@ -17,6 +16,7 @@ import lancedb # type: ignore import pyarrow as pa +import pyarrow.compute as pc from lancedb import DBConnection from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore @@ -77,7 +77,6 @@ else: NDArray = ndarray - TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} @@ -145,7 +144,7 @@ def from_db_type( if (precision, scale) == self.capabilities.wei_precision: return cast(TColumnType, dict(data_type="wei")) return dict(data_type="decimal", precision=precision, scale=scale) - return super().from_db_type(db_type, precision, scale) + return super().from_db_type(cast(str, db_type), precision, scale) def upload_batch( @@ -154,7 +153,8 @@ def upload_batch( *, db_client: DBConnection, table_name: str, - write_disposition: TWriteDisposition, + parent_table_name: Optional[str] = None, + write_disposition: Optional[TWriteDisposition] = "append", id_field_name: Optional[str] = None, ) -> None: """Inserts records into a LanceDB table with automatic embedding computation. @@ -163,6 +163,7 @@ def upload_batch( records: The data to be inserted as payload. db_client: The LanceDB client connection. table_name: The name of the table to insert into. + parent_table_name: The name of the parent table, if the target table has any. id_field_name: The name of the ID field for update/merge operations. write_disposition: The write disposition - one of 'skip', 'append', 'replace', 'merge'. @@ -190,6 +191,23 @@ def upload_batch( tbl.merge_insert( id_field_name ).when_matched_update_all().when_not_matched_insert_all().execute(records) + + # Remove orphaned parent IDs. 
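+        # A child row counts as orphaned when its _dlt_parent_id no longer appears among the parent table's _dlt_id values,
+        # for example because a re-chunked document now produces fewer chunks; such rows are deleted after the merge.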
+ if parent_table_name: + try: + parent_tbl = db_client.open_table(parent_table_name) + parent_tbl.checkout_latest() + except FileNotFoundError as e: + raise DestinationTransientException( + "Couldn't open lancedb database. Batch WILL BE RETRIED" + ) from e + + parent_ids = set(pc.unique(parent_tbl.to_arrow()["_dlt_id"]).to_pylist()) + child_ids = set(pc.unique(tbl.to_arrow()["_dlt_parent_id"]).to_pylist()) + + if orphaned_ids := child_ids - parent_ids: + tbl.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") + else: raise DestinationTerminalException( f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" @@ -334,7 +352,7 @@ def drop_storage(self) -> None: Deletes all tables in the dataset and all data, as well as sentinel table associated with them. - If the dataset name was not provided, it deletes all the tables in the current schema. + If the dataset name wasn't provided, it deletes all the tables in the current schema. """ for table_name in self._get_table_names(): self.db_client.drop_table(table_name) @@ -459,9 +477,6 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: existing_columns, self.capabilities.generates_case_sensitive_identifiers(), ) - embedding_fields: List[str] = get_columns_names_with_prop( - self.schema.get_table(table_name), VECTORIZE_HINT - ) logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") if len(new_columns) > 0: if exists: @@ -524,10 +539,12 @@ def update_schema_in_storage(self) -> None: write_disposition = self.schema.get_table(self.schema.version_table_name).get( "write_disposition" ) + upload_batch( records, db_client=self.db_client, table_name=fq_version_table_name, + parent_table_name=None, write_disposition=write_disposition, ) @@ -690,6 +707,8 @@ def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + parent_table = table.get("parent") + return LoadLanceDBJob( self.schema, table, @@ -699,6 +718,9 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> client_config=self.config, model_func=self.model_func, fq_table_name=self.make_qualified_table_name(table["name"]), + fq_parent_table_name=( + self.make_qualified_table_name(parent_table) if parent_table else None + ), ) def table_exists(self, table_name: str) -> bool: @@ -718,6 +740,7 @@ def __init__( client_config: LanceDBClientConfiguration, model_func: TextEmbeddingFunction, fq_table_name: str, + fq_parent_table_name: Optional[str], ) -> None: file_name = FileStorage.get_file_name_from_file_path(local_path) super().__init__(file_name) @@ -727,6 +750,7 @@ def __init__( self.type_mapper: TypeMapper = type_mapper self.table_name: str = table_schema["name"] self.fq_table_name: str = fq_table_name + self.fq_parent_table_name: Optional[str] = fq_parent_table_name self.unique_identifiers: Sequence[str] = list_merge_identifiers(table_schema) self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) self.embedding_model_func: TextEmbeddingFunction = model_func @@ -759,6 +783,7 @@ def __init__( records, db_client=db_client, table_name=self.fq_table_name, + parent_table_name=self.fq_parent_table_name, write_disposition=self.write_disposition, id_field_name=self.id_field_name, ) diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 3f2f1298ba..bc4af48c50 100644 --- 
a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -426,9 +426,7 @@ def test_empty_dataset_allowed() -> None: client: LanceDBClient = pipe.destination_client() # type: ignore[assignment] assert pipe.dataset_name is None - info = pipe.run( - lancedb_adapter(["context", "created", "not a stop word"], embed=["value"]) - ) + info = pipe.run(lancedb_adapter(["context", "created", "not a stop word"], embed=["value"])) # Dataset in load info is empty. assert info.dataset_name is None client = pipe.destination_client() # type: ignore[assignment] @@ -440,8 +438,10 @@ def test_empty_dataset_allowed() -> None: docs = [ [ { - "text": "This is the first document. It contains some text that will be chunked and embedded. (I don't want " - "to be seen in updated run's embedding chunk texts btw)", + "text": ( + "This is the first document. It contains some text that will be chunked and" + " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)" + ), "id": 1, }, { @@ -462,10 +462,6 @@ def test_empty_dataset_allowed() -> None: ] -def splitter(text: str, chunk_size: int = 10) -> List[str]: - return [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)] - - def test_chunking_no_splitter() -> None: pipe = dlt.pipeline(destination="lancedb", dataset_name="docs", dev_mode=True) info = pipe.run( @@ -477,24 +473,9 @@ def test_chunking_no_splitter() -> None: # TODO: Check and compare output -def test_chunking_with_splitter() -> None: - pipe = dlt.pipeline(destination="lancedb", dataset_name="docs", dev_mode=True) - - info = pipe.run( - lancedb_adapter(docs[0], embed="text", splitter=splitter), - table_name="documents", - ) - assert_load_info(info) - - # TODO: Check and compare output - - -def test_chunk_merge() -> None: - """Test chunking is applied without orphaned chunks when new documents arrive.""" - +def test_chunk_merge_no_splitter() -> None: pipe = dlt.pipeline(destination="lancedb", dataset_name="docs", dev_mode=True) - info = pipe.run( lancedb_adapter(docs[0], embed="text", splitter=splitter), table_name="documents", From 32fe174590f6e63edcfae4fde144855692ce2861 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 21 Jul 2024 22:42:58 +0200 Subject: [PATCH 07/68] Implement efficient update strategy for chunked documents in LanceDB Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 17 ++- tests/load/lancedb/test_pipeline.py | 144 +++++++++++------- tests/load/lancedb/utils.py | 25 ++- 3 files changed, 121 insertions(+), 65 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index e3025d19bf..05e327a763 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -141,7 +141,7 @@ def from_db_type( ) if isinstance(db_type, pa.Decimal128Type): precision, scale = db_type.precision, db_type.scale - if (precision, scale) == self.capabilities.wei_precision: + if (precision, scale)==self.capabilities.wei_precision: return cast(TColumnType, dict(data_type="wei")) return dict(data_type="decimal", precision=precision, scale=scale) return super().from_db_type(cast(str, db_type), precision, scale) @@ -183,9 +183,9 @@ def upload_batch( try: if write_disposition in ("append", "skip"): tbl.add(records) - elif write_disposition == "replace": + elif write_disposition=="replace": tbl.add(records, mode="overwrite") - elif write_disposition == "merge": + elif write_disposition=="merge": if not id_field_name: 
raise ValueError("To perform a merge update, 'id_field_name' must be specified.") tbl.merge_insert( @@ -206,7 +206,10 @@ def upload_batch( child_ids = set(pc.unique(tbl.to_arrow()["_dlt_parent_id"]).to_pylist()) if orphaned_ids := child_ids - parent_ids: - tbl.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") + if len(orphaned_ids) > 1: + tbl.delete(f"_dlt_parent_id IN {tuple(orphaned_ids) if len(orphaned_ids) > 1 else orphaned_ids.pop()}") + elif len(orphaned_ids) == 1: + tbl.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") else: raise DestinationTerminalException( @@ -251,7 +254,7 @@ def __init__( self.config.credentials.embedding_model_provider_api_key, ) # Use the monkey-patched implementation if openai was chosen. - if embedding_model_provider == "openai": + if embedding_model_provider=="openai": from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings self.model_func = PatchedOpenAIEmbeddings( @@ -344,7 +347,7 @@ def _get_table_names(self) -> List[str]: else: table_names = self.db_client.table_names() - return [table_name for table_name in table_names if table_name != self.sentinel_table] + return [table_name for table_name in table_names if table_name!=self.sentinel_table] @lancedb_error def drop_storage(self) -> None: @@ -585,7 +588,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: loads_table, keys=p_dlt_load_id, right_keys=p_load_id, join_type="inner" ).sort_by([(p_dlt_load_id, "descending")]) - if joined_table.num_rows == 0: + if joined_table.num_rows==0: return None state = joined_table.take([0]).to_pylist()[0] diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index bc4af48c50..a0776559f8 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -1,17 +1,21 @@ from typing import Iterator, Generator, Any, List +from typing import Union, Dict import pytest +from lancedb.table import Table import dlt from dlt.common import json -from dlt.common.typing import DictStrStr, DictStrAny -from dlt.common.utils import uniq_id +from dlt.common.typing import DictStrAny +from dlt.common.typing import DictStrStr +from dlt.common.utils import uniq_id, digest128 from dlt.destinations.impl.lancedb.lancedb_adapter import ( lancedb_adapter, VECTORIZE_HINT, ) from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient -from tests.load.lancedb.utils import assert_table +from dlt.extract import DltResource +from tests.load.lancedb.utils import assert_table, chunk_document, mock_embed from tests.load.utils import sequence_generator, drop_active_pipeline_data from tests.pipeline.utils import assert_load_info @@ -426,7 +430,9 @@ def test_empty_dataset_allowed() -> None: client: LanceDBClient = pipe.destination_client() # type: ignore[assignment] assert pipe.dataset_name is None - info = pipe.run(lancedb_adapter(["context", "created", "not a stop word"], embed=["value"])) + info = pipe.run( + lancedb_adapter(["context", "created", "not a stop word"], embed=["value"]) + ) # Dataset in load info is empty. 
assert info.dataset_name is None client = pipe.destination_client() # type: ignore[assignment] @@ -435,67 +441,97 @@ def test_empty_dataset_allowed() -> None: assert_table(pipe, "content", expected_items_count=3) -docs = [ - [ +def test_merge_no_orphans() -> None: + @dlt.resource( + write_disposition="merge", + merge_key=["doc_id", "chunk_hash"], + table_name="document", + ) + def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: + for doc in docs: + doc_id = doc["doc_id"] + chunks = chunk_document(doc["text"]) + embeddings = [ + { + "chunk_hash": digest128(chunk), + "chunk_text": chunk, + "embedding": mock_embed(), + } + for chunk in chunks + ] + yield {"doc_id": doc_id, "doc_text": doc["text"], "embeddings": embeddings} + + @dlt.source(max_table_nesting=1) + def documents_source( + docs: List[DictStrAny], + ) -> Union[Generator[Dict[str, Any], None, None], DltResource]: + return documents(docs) + + pipeline = dlt.pipeline( + pipeline_name="chunked_docs", + destination="lancedb", + dataset_name="chunked_documents", + dev_mode=True, + ) + + initial_docs = [ { - "text": ( - "This is the first document. It contains some text that will be chunked and" - " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)" - ), - "id": 1, + "text": "This is the first document. It contains some text that will be chunked and embedded. (I don't want " + "to be seen in updated run's embedding chunk texts btw)", + "doc_id": 1, }, { "text": "Here's another document. It's a bit different from the first one.", - "id": 2, + "doc_id": 2, }, - ], - [ + ] + + info = pipeline.run(documents_source(initial_docs)) + assert_load_info(info) + + updated_docs = [ { "text": "This is the first document, but it has been updated with new content.", - "id": 1, + "doc_id": 1, }, { "text": "This is a completely new document that wasn't in the initial set.", - "id": 3, + "doc_id": 3, }, - ], -] - - -def test_chunking_no_splitter() -> None: - pipe = dlt.pipeline(destination="lancedb", dataset_name="docs", dev_mode=True) - info = pipe.run( - docs[0], - table_name="documents", - ) - assert_load_info(info) - - # TODO: Check and compare output - - -def test_chunk_merge_no_splitter() -> None: - pipe = dlt.pipeline(destination="lancedb", dataset_name="docs", dev_mode=True) - - info = pipe.run( - lancedb_adapter(docs[0], embed="text", splitter=splitter), - table_name="documents", - write_disposition="merge", - primary_key="id", - ) - pipe.run(info) + ] - # Orphaned chunks must be discarded. - info = pipe.run( - lancedb_adapter(docs[1], embed="text", splitter=splitter), - table_name="documents", - write_disposition="merge", - primary_key="id", - ) + info = pipeline.run(documents_source(updated_docs)) assert_load_info(info) - # TODO: Check and compare output - - -def test_embedding_provider_only_called_once_per_chunk_hash() -> None: - """Verify that the embedding provider is called only once for each unique chunk hash to optimize API usage and reduce costs.""" - raise NotImplementedError + with pipeline.destination_client() as client: + # Orphaned chunks/documents must have been discarded. + # Shouldn't contain any text from `initial_docs' where doc_id=1. + expected_text = { + "Here's ano", + "ther docum", + "ent. 
It's ", + "a bit diff", + "erent from", + " the first", + " one.", + "This is th", + "e first do", + "cument, bu", + "t it has b", + "een update", + "d with new", + " content.", + "This is a ", + "completely", + " new docum", + "ent that w", + "asn't in t", + "he initial", + " set.", + } + + embeddings_table_name = client.make_qualified_table_name("document__embeddings") # type: ignore[attr-defined] + + tbl: Table = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] + df = tbl.to_pandas() + assert set(df["chunk_text"]) == expected_text diff --git a/tests/load/lancedb/utils.py b/tests/load/lancedb/utils.py index dc3ea5304b..2fdc8e5b40 100644 --- a/tests/load/lancedb/utils.py +++ b/tests/load/lancedb/utils.py @@ -22,8 +22,12 @@ def assert_unordered_dicts_equal( """ assert len(dict_list1) == len(dict_list2), "Lists have different length" - dict_set1 = {tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list1} - dict_set2 = {tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list2} + dict_set1 = { + tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list1 + } + dict_set2 = { + tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list2 + } assert dict_set1 == dict_set2, "Lists contain different dictionaries" @@ -40,7 +44,9 @@ def assert_table( exists = client.table_exists(qualified_table_name) assert exists - records = client.db_client.open_table(qualified_table_name).search().limit(50).to_list() + records = ( + client.db_client.open_table(qualified_table_name).search().limit(50).to_list() + ) if expected_items_count is not None: assert expected_items_count == len(records) @@ -52,7 +58,8 @@ def assert_table( "_dlt_id", "_dlt_load_id", dlt.config.get("destination.lancedb.credentials.id_field_name", str) or "id__", - dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector__", + dlt.config.get("destination.lancedb.credentials.vector_field_name", str) + or "vector__", ] objects_without_dlt_or_special_keys = [ {k: v for k, v in record.items() if k not in drop_keys} for record in records @@ -72,3 +79,13 @@ def generate_embeddings( def ndims(self) -> int: return 2 + + +def mock_embed( + dim: int = 10, +) -> str: + return str(np.random.random_sample(dim)) + + +def chunk_document(doc: str, chunk_size: int = 10) -> List[str]: + return [doc[i : i + chunk_size] for i in range(0, len(doc), chunk_size)] From d97496289d55bfdc3146e4dcef80ad2cadab7f49 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 21 Jul 2024 22:58:18 +0200 Subject: [PATCH 08/68] Refactor LanceDB client and tests for improved readability and type safety Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 17 ++++++++++------- tests/load/lancedb/test_pipeline.py | 12 ++++++------ tests/load/lancedb/utils.py | 15 ++++----------- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 05e327a763..981b062246 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -141,7 +141,7 @@ def from_db_type( ) if isinstance(db_type, pa.Decimal128Type): precision, scale = db_type.precision, db_type.scale - if (precision, scale)==self.capabilities.wei_precision: + if (precision, scale) == self.capabilities.wei_precision: return cast(TColumnType, dict(data_type="wei")) return dict(data_type="decimal", 
precision=precision, scale=scale) return super().from_db_type(cast(str, db_type), precision, scale) @@ -183,9 +183,9 @@ def upload_batch( try: if write_disposition in ("append", "skip"): tbl.add(records) - elif write_disposition=="replace": + elif write_disposition == "replace": tbl.add(records, mode="overwrite") - elif write_disposition=="merge": + elif write_disposition == "merge": if not id_field_name: raise ValueError("To perform a merge update, 'id_field_name' must be specified.") tbl.merge_insert( @@ -207,7 +207,10 @@ def upload_batch( if orphaned_ids := child_ids - parent_ids: if len(orphaned_ids) > 1: - tbl.delete(f"_dlt_parent_id IN {tuple(orphaned_ids) if len(orphaned_ids) > 1 else orphaned_ids.pop()}") + tbl.delete( + "_dlt_parent_id IN" + f" {tuple(orphaned_ids) if len(orphaned_ids) > 1 else orphaned_ids.pop()}" + ) elif len(orphaned_ids) == 1: tbl.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") @@ -254,7 +257,7 @@ def __init__( self.config.credentials.embedding_model_provider_api_key, ) # Use the monkey-patched implementation if openai was chosen. - if embedding_model_provider=="openai": + if embedding_model_provider == "openai": from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings self.model_func = PatchedOpenAIEmbeddings( @@ -347,7 +350,7 @@ def _get_table_names(self) -> List[str]: else: table_names = self.db_client.table_names() - return [table_name for table_name in table_names if table_name!=self.sentinel_table] + return [table_name for table_name in table_names if table_name != self.sentinel_table] @lancedb_error def drop_storage(self) -> None: @@ -588,7 +591,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: loads_table, keys=p_dlt_load_id, right_keys=p_load_id, join_type="inner" ).sort_by([(p_dlt_load_id, "descending")]) - if joined_table.num_rows==0: + if joined_table.num_rows == 0: return None state = joined_table.take([0]).to_pylist()[0] diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index a0776559f8..74d97cb88f 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -2,7 +2,7 @@ from typing import Union, Dict import pytest -from lancedb.table import Table +from lancedb.table import Table # type: ignore[import-untyped] import dlt from dlt.common import json @@ -430,9 +430,7 @@ def test_empty_dataset_allowed() -> None: client: LanceDBClient = pipe.destination_client() # type: ignore[assignment] assert pipe.dataset_name is None - info = pipe.run( - lancedb_adapter(["context", "created", "not a stop word"], embed=["value"]) - ) + info = pipe.run(lancedb_adapter(["context", "created", "not a stop word"], embed=["value"])) # Dataset in load info is empty. assert info.dataset_name is None client = pipe.destination_client() # type: ignore[assignment] @@ -476,8 +474,10 @@ def documents_source( initial_docs = [ { - "text": "This is the first document. It contains some text that will be chunked and embedded. (I don't want " - "to be seen in updated run's embedding chunk texts btw)", + "text": ( + "This is the first document. It contains some text that will be chunked and" + " embedded. 
(I don't want to be seen in updated run's embedding chunk texts btw)" + ), "doc_id": 1, }, { diff --git a/tests/load/lancedb/utils.py b/tests/load/lancedb/utils.py index 2fdc8e5b40..8dd56d22aa 100644 --- a/tests/load/lancedb/utils.py +++ b/tests/load/lancedb/utils.py @@ -22,12 +22,8 @@ def assert_unordered_dicts_equal( """ assert len(dict_list1) == len(dict_list2), "Lists have different length" - dict_set1 = { - tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list1 - } - dict_set2 = { - tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list2 - } + dict_set1 = {tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list1} + dict_set2 = {tuple(sorted((k, v) for k, v in d.items() if v is not None)) for d in dict_list2} assert dict_set1 == dict_set2, "Lists contain different dictionaries" @@ -44,9 +40,7 @@ def assert_table( exists = client.table_exists(qualified_table_name) assert exists - records = ( - client.db_client.open_table(qualified_table_name).search().limit(50).to_list() - ) + records = client.db_client.open_table(qualified_table_name).search().limit(50).to_list() if expected_items_count is not None: assert expected_items_count == len(records) @@ -58,8 +52,7 @@ def assert_table( "_dlt_id", "_dlt_load_id", dlt.config.get("destination.lancedb.credentials.id_field_name", str) or "id__", - dlt.config.get("destination.lancedb.credentials.vector_field_name", str) - or "vector__", + dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector__", ] objects_without_dlt_or_special_keys = [ {k: v for k, v in record.items() if k not in drop_keys} for record in records From e6cdf5d181343c78090ef86c8cbceb6347ab7c69 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 21 Jul 2024 23:20:06 +0200 Subject: [PATCH 09/68] Linting Signed-off-by: Marcel Coetzee --- tests/load/lancedb/test_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 74d97cb88f..e312a1e591 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -2,7 +2,7 @@ from typing import Union, Dict import pytest -from lancedb.table import Table # type: ignore[import-untyped] +from lancedb.table import Table # type: ignore import dlt from dlt.common import json From a60737aeaa1dc84f6735c7e62afe4ca6d6ac5168 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sat, 27 Jul 2024 17:13:27 +0200 Subject: [PATCH 10/68] Add document_id parameter to lancedb_adapter and update merge logic Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_adapter.py | 20 +++- dlt/destinations/impl/lancedb/utils.py | 4 +- tests/load/lancedb/test_pipeline.py | 113 +++++++++++++++++- 3 files changed, 134 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_adapter.py b/dlt/destinations/impl/lancedb/lancedb_adapter.py index bb33632b48..faa2ab3399 100644 --- a/dlt/destinations/impl/lancedb/lancedb_adapter.py +++ b/dlt/destinations/impl/lancedb/lancedb_adapter.py @@ -6,11 +6,13 @@ VECTORIZE_HINT = "x-lancedb-embed" +DOCUMENT_ID_HINT = "x-lancedb-doc-id" def lancedb_adapter( data: Any, embed: TColumnNames = None, + document_id: TColumnNames = None, ) -> DltResource: """Prepares data for the LanceDB destination by specifying which columns should be embedded. @@ -20,6 +22,8 @@ def lancedb_adapter( object. embed (TColumnNames, optional): Specify columns to generate embeddings for. 
It can be a single column name as a string, or a list of column names. + document_id (TColumnNames, optional): Specify columns which represenet the document + and which will be appended to primary/merge keys. Returns: DltResource: A resource with applied LanceDB-specific hints. @@ -50,8 +54,22 @@ def lancedb_adapter( VECTORIZE_HINT: True, # type: ignore[misc] } + if document_id: + if isinstance(document_id, str): + embed = [document_id] + if not isinstance(document_id, list): + raise ValueError( + "'document_id' must be a list of column names or a single column name as a string." + ) + + for column_name in document_id: + column_hints[column_name] = { + "name": column_name, + DOCUMENT_ID_HINT: True, # type: ignore[misc] + } + if not column_hints: - raise ValueError("A value for 'embed' must be specified.") + raise ValueError("At least one of 'embed' or 'document_id' must be specified.") else: resource.apply_hints(columns=column_hints) diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index aeacd4d34b..874852512a 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -6,6 +6,7 @@ from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider +from dlt.destinations.impl.lancedb.lancedb_adapter import DOCUMENT_ID_HINT PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = { @@ -41,9 +42,10 @@ def list_merge_identifiers(table_schema: TTableSchema) -> Sequence[str]: Sequence[str]: A list of unique column identifiers. """ if table_schema.get("write_disposition") == "merge": + document_id = get_columns_names_with_prop(table_schema, DOCUMENT_ID_HINT) primary_keys = get_columns_names_with_prop(table_schema, "primary_key") merge_keys = get_columns_names_with_prop(table_schema, "merge_key") - if join_keys := list(set(primary_keys + merge_keys)): + if join_keys := list(set(primary_keys + merge_keys + document_id)): return join_keys return get_columns_names_with_prop(table_schema, "unique") diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index e312a1e591..ea52e2f472 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -12,6 +12,7 @@ from dlt.destinations.impl.lancedb.lancedb_adapter import ( lancedb_adapter, VECTORIZE_HINT, + DOCUMENT_ID_HINT, ) from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient from dlt.extract import DltResource @@ -50,6 +51,18 @@ def some_data() -> Generator[DictStrStr, Any, None]: "x-lancedb-embed": True, } + lancedb_adapter( + some_data, + document_id=["content"], + ) + + assert some_data.columns["content"] == { # type: ignore + "name": "content", + "data_type": "text", + "x-lancedb-embed": True, + "x-lancedb-doc-id": True, + } + def test_basic_state_and_schema() -> None: generator_instance1 = sequence_generator() @@ -442,8 +455,106 @@ def test_empty_dataset_allowed() -> None: def test_merge_no_orphans() -> None: @dlt.resource( write_disposition="merge", - merge_key=["doc_id", "chunk_hash"], + primary_key=["doc_id"], + table_name="document", + ) + def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: + for doc in docs: + doc_id = doc["doc_id"] + chunks = chunk_document(doc["text"]) + embeddings = [ + { + "chunk_hash": digest128(chunk), + "chunk_text": chunk, + "embedding": mock_embed(), + } + for chunk in chunks + ] + yield {"doc_id": doc_id, 
"doc_text": doc["text"], "embeddings": embeddings} + + @dlt.source(max_table_nesting=1) + def documents_source( + docs: List[DictStrAny], + ) -> Union[Generator[Dict[str, Any], None, None], DltResource]: + return documents(docs) + + pipeline = dlt.pipeline( + pipeline_name="chunked_docs", + destination="lancedb", + dataset_name="chunked_documents", + dev_mode=True, + ) + + initial_docs = [ + { + "text": ( + "This is the first document. It contains some text that will be chunked and" + " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)" + ), + "doc_id": 1, + }, + { + "text": "Here's another document. It's a bit different from the first one.", + "doc_id": 2, + }, + ] + + info = pipeline.run(documents_source(initial_docs)) + assert_load_info(info) + + updated_docs = [ + { + "text": "This is the first document, but it has been updated with new content.", + "doc_id": 1, + }, + { + "text": "This is a completely new document that wasn't in the initial set.", + "doc_id": 3, + }, + ] + + info = pipeline.run(documents_source(updated_docs)) + assert_load_info(info) + + with pipeline.destination_client() as client: + # Orphaned chunks/documents must have been discarded. + # Shouldn't contain any text from `initial_docs' where doc_id=1. + expected_text = { + "Here's ano", + "ther docum", + "ent. It's ", + "a bit diff", + "erent from", + " the first", + " one.", + "This is th", + "e first do", + "cument, bu", + "t it has b", + "een update", + "d with new", + " content.", + "This is a ", + "completely", + " new docum", + "ent that w", + "asn't in t", + "he initial", + " set.", + } + + embeddings_table_name = client.make_qualified_table_name("document__embeddings") # type: ignore[attr-defined] + + tbl: Table = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] + df = tbl.to_pandas() + assert set(df["chunk_text"]) == expected_text + + +def test_merge_no_orphans_with_doc_id() -> None: + @dlt.resource( # type: ignore + write_disposition="merge", table_name="document", + columns={"doc_id": {DOCUMENT_ID_HINT: True}}, ) def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: for doc in docs: From 518a5078ad068a51627888d615e5090b2cf82501 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 29 Jul 2024 16:13:50 +0200 Subject: [PATCH 11/68] Remove resolved comments Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 122 +++++++++++++----- 1 file changed, 89 insertions(+), 33 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 981b062246..09c1cca372 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -78,7 +78,9 @@ NDArray = ndarray TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} -UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} +UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = { + v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items() +} class LanceDBTypeMapper(TypeMapper): @@ -187,7 +189,9 @@ def upload_batch( tbl.add(records, mode="overwrite") elif write_disposition == "merge": if not id_field_name: - raise ValueError("To perform a merge update, 'id_field_name' must be specified.") + raise ValueError( + "To perform a merge update, 'id_field_name' must be specified." 
+ ) tbl.merge_insert( id_field_name ).when_matched_update_all().when_not_matched_insert_all().execute(records) @@ -202,7 +206,9 @@ def upload_batch( "Couldn't open lancedb database. Batch WILL BE RETRIED" ) from e - parent_ids = set(pc.unique(parent_tbl.to_arrow()["_dlt_id"]).to_pylist()) + parent_ids = set( + pc.unique(parent_tbl.to_arrow()["_dlt_id"]).to_pylist() + ) child_ids = set(pc.unique(tbl.to_arrow()["_dlt_parent_id"]).to_pylist()) if orphaned_ids := child_ids - parent_ids: @@ -299,7 +305,9 @@ def get_table_schema(self, table_name: str) -> TArrowSchema: ) @lancedb_error - def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table: + def create_table( + self, table_name: str, schema: TArrowSchema, mode: str = "create" + ) -> Table: """Create a LanceDB Table from the provided LanceModel or PyArrow schema. Args: @@ -322,7 +330,9 @@ def delete_table(self, table_name: str) -> None: def query_table( self, table_name: str, - query: Union[List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None] = None, + query: Union[ + List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None + ] = None, ) -> LanceQueryBuilder: """Query a LanceDB table. @@ -350,7 +360,11 @@ def _get_table_names(self) -> List[str]: else: table_names = self.db_client.table_names() - return [table_name for table_name in table_names if table_name != self.sentinel_table] + return [ + table_name + for table_name in table_names + if table_name != self.sentinel_table + ] @lancedb_error def drop_storage(self) -> None: @@ -403,7 +417,9 @@ def update_stored_schema( applied_update: TSchemaTables = {} try: - schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash) + schema_info = self.get_stored_schema_by_hash( + self.schema.stored_version_hash + ) except DestinationUndefinedEntity: schema_info = None @@ -458,19 +474,25 @@ def add_table_fields( # Check if any of the new fields already exist in the table. existing_fields = set(arrow_table.schema.names) - new_fields = [field for field in field_schemas if field.name not in existing_fields] + new_fields = [ + field for field in field_schemas if field.name not in existing_fields + ] if not new_fields: # All fields already present, skip. return None - null_arrays = [pa.nulls(len(arrow_table), type=field.type) for field in new_fields] + null_arrays = [ + pa.nulls(len(arrow_table), type=field.type) for field in new_fields + ] for field, null_array in zip(new_fields, null_arrays): arrow_table = arrow_table.append_column(field, null_array) try: - return self.db_client.create_table(table_name, arrow_table, mode="overwrite") + return self.db_client.create_table( + table_name, arrow_table, mode="overwrite" + ) except OSError: # Error occurred while creating the table, skip. 
return None @@ -483,11 +505,15 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: existing_columns, self.capabilities.generates_case_sensitive_identifiers(), ) - logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") + logger.info( + f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}" + ) if len(new_columns) > 0: if exists: field_schemas: List[TArrowField] = [ - make_arrow_field_schema(column["name"], column, self.type_mapper) + make_arrow_field_schema( + column["name"], column, self.type_mapper + ) for column in new_columns ] fq_table_name = self.make_qualified_table_name(table_name) @@ -500,7 +526,9 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: vector_field_name = self.vector_field_name id_field_name = self.id_field_name embedding_model_func = self.model_func - embedding_model_dimensions = self.config.embedding_model_dimensions + embedding_model_dimensions = ( + self.config.embedding_model_dimensions + ) else: embedding_fields = None vector_field_name = None @@ -531,8 +559,12 @@ def update_schema_in_storage(self) -> None: self.schema.naming.normalize_identifier( "engine_version" ): self.schema.ENGINE_VERSION, - self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), - self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier("inserted_at"): str( + pendulum.now() + ), + self.schema.naming.normalize_identifier( + "schema_name" + ): self.schema.name, self.schema.naming.normalize_identifier( "version_hash" ): self.schema.stored_version_hash, @@ -541,7 +573,9 @@ def update_schema_in_storage(self) -> None: ), } ] - fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) + fq_version_table_name = self.make_qualified_table_name( + self.schema.version_table_name + ) write_disposition = self.schema.get_table(self.schema.version_table_name).get( "write_disposition" ) @@ -557,8 +591,12 @@ def update_schema_in_storage(self) -> None: @lancedb_error def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: """Retrieves the latest completed state for a pipeline.""" - fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name) - fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) + fq_state_table_name = self.make_qualified_table_name( + self.schema.state_table_name + ) + fq_loads_table_name = self.make_qualified_table_name( + self.schema.loads_table_name + ) state_table_: Table = self.db_client.open_table(fq_state_table_name) state_table_.checkout_latest() @@ -584,7 +622,9 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: .where(f"`{p_pipeline_name}` = '{pipeline_name}'", prefilter=True) .to_arrow() ) - loads_table = loads_table_.search().where(f"`{p_status}` = 0", prefilter=True).to_arrow() + loads_table = ( + loads_table_.search().where(f"`{p_status}` = 0", prefilter=True).to_arrow() + ) # Join arrow tables in-memory. 
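        # Only state rows whose load id matches a completed (status 0) load survive the inner join; the newest of these is returned.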
joined_table: pa.Table = state_table.join( @@ -606,8 +646,12 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: ) @lancedb_error - def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: - fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) + def get_stored_schema_by_hash( + self, schema_hash: str + ) -> Optional[StorageSchemaInfo]: + fq_version_table_name = self.make_qualified_table_name( + self.schema.version_table_name + ) version_table: Table = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() @@ -625,9 +669,9 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI ) ).to_list() - # LanceDB's ORDER BY clause doesn't seem to work. - # See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341 - most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] + most_recent_schema = sorted( + schemas, key=lambda x: x[p_inserted_at], reverse=True + )[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], @@ -642,7 +686,9 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI @lancedb_error def get_stored_schema(self) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage.""" - fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) + fq_version_table_name = self.make_qualified_table_name( + self.schema.version_table_name + ) version_table: Table = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() @@ -660,9 +706,9 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: ) ).to_list() - # LanceDB's ORDER BY clause doesn't seem to work. - # See https://github.com/dlt-hub/dlt/pull/1375#issuecomment-2171909341 - most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] + most_recent_schema = sorted( + schemas, key=lambda x: x[p_inserted_at], reverse=True + )[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], @@ -690,15 +736,21 @@ def complete_load(self, load_id: str) -> None: records = [ { self.schema.naming.normalize_identifier("load_id"): load_id, - self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier( + "schema_name" + ): self.schema.name, self.schema.naming.normalize_identifier("status"): 0, - self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), + self.schema.naming.normalize_identifier("inserted_at"): str( + pendulum.now() + ), self.schema.naming.normalize_identifier( "schema_version_hash" ): None, # Payload schema must match the target schema. 
} ] - fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) + fq_loads_table_name = self.make_qualified_table_name( + self.schema.loads_table_name + ) write_disposition = self.schema.get_table(self.schema.loads_table_name).get( "write_disposition" ) @@ -712,7 +764,9 @@ def complete_load(self, load_id: str) -> None: def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: parent_table = table.get("parent") return LoadLanceDBJob( @@ -758,7 +812,9 @@ def __init__( self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name self.unique_identifiers: Sequence[str] = list_merge_identifiers(table_schema) - self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) + self.embedding_fields: List[str] = get_columns_names_with_prop( + table_schema, VECTORIZE_HINT + ) self.embedding_model_func: TextEmbeddingFunction = model_func self.embedding_model_dimensions: int = client_config.embedding_model_dimensions self.id_field_name: str = client_config.id_field_name From c10bd734f8742b4b6a910c18d52c10f9279cedcc Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 29 Jul 2024 17:03:33 +0200 Subject: [PATCH 12/68] Implement efficient orphan removal for chunked documents in LanceDB Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 104 +++++++++++++----- 1 file changed, 74 insertions(+), 30 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 09c1cca372..4e3d5c7488 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -38,6 +38,7 @@ StorageSchemaInfo, StateInfo, TLoadJobState, + FollowupJob, ) from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables @@ -155,7 +156,6 @@ def upload_batch( *, db_client: DBConnection, table_name: str, - parent_table_name: Optional[str] = None, write_disposition: Optional[TWriteDisposition] = "append", id_field_name: Optional[str] = None, ) -> None: @@ -165,7 +165,6 @@ def upload_batch( records: The data to be inserted as payload. db_client: The LanceDB client connection. table_name: The name of the table to insert into. - parent_table_name: The name of the parent table, if the target table has any. id_field_name: The name of the ID field for update/merge operations. write_disposition: The write disposition - one of 'skip', 'append', 'replace', 'merge'. @@ -187,7 +186,7 @@ def upload_batch( tbl.add(records) elif write_disposition == "replace": tbl.add(records, mode="overwrite") - elif write_disposition == "merge": + elif write_disposition in ("merge" or "upsert"): if not id_field_name: raise ValueError( "To perform a merge update, 'id_field_name' must be specified." @@ -195,31 +194,6 @@ def upload_batch( tbl.merge_insert( id_field_name ).when_matched_update_all().when_not_matched_insert_all().execute(records) - - # Remove orphaned parent IDs. - if parent_table_name: - try: - parent_tbl = db_client.open_table(parent_table_name) - parent_tbl.checkout_latest() - except FileNotFoundError as e: - raise DestinationTransientException( - "Couldn't open lancedb database. 
Batch WILL BE RETRIED" - ) from e - - parent_ids = set( - pc.unique(parent_tbl.to_arrow()["_dlt_id"]).to_pylist() - ) - child_ids = set(pc.unique(tbl.to_arrow()["_dlt_parent_id"]).to_pylist()) - - if orphaned_ids := child_ids - parent_ids: - if len(orphaned_ids) > 1: - tbl.delete( - "_dlt_parent_id IN" - f" {tuple(orphaned_ids) if len(orphaned_ids) > 1 else orphaned_ids.pop()}" - ) - elif len(orphaned_ids) == 1: - tbl.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") - else: raise DestinationTerminalException( f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" @@ -584,7 +558,6 @@ def update_schema_in_storage(self) -> None: records, db_client=self.db_client, table_name=fq_version_table_name, - parent_table_name=None, write_disposition=write_disposition, ) @@ -845,7 +818,6 @@ def __init__( records, db_client=db_client, table_name=self.fq_table_name, - parent_table_name=self.fq_parent_table_name, write_disposition=self.write_disposition, id_field_name=self.id_field_name, ) @@ -855,3 +827,75 @@ def state(self) -> TLoadJobState: def exception(self) -> str: raise NotImplementedError() + + +class LanceDBRemoveOrphansJob(LoadJob, FollowupJob): + def __init__( + self, + db_client: DBConnection, + file_name: str, + table_schema: TTableSchema, + fq_table_name: str, + fq_parent_table_name: Optional[str], + ) -> None: + super().__init__(file_name) + self.db_client = db_client + self._state: TLoadJobState = "running" + self.table_schema: TTableSchema = table_schema + self.fq_table_name: str = fq_table_name + self.fq_parent_table_name: Optional[str] = fq_parent_table_name + self.write_disposition: TWriteDisposition = cast( + TWriteDisposition, self.table_schema.get("write_disposition", "append") + ) + + def execute(self) -> None: + if self.write_disposition not in ("merge" or "upsert"): + raise DestinationTerminalException( + f"Unsupported write disposition {self.write_disposition} for LanceDB Destination Orphan Removal Job - failed AND WILL **NOT** BE RETRIED." + ) + + try: + child_table = self.db_client.open_table(self.fq_table_name) + child_table.checkout_latest() + if self.fq_parent_table_name: + parent_table = self.db_client.open_table(self.fq_parent_table_name) + parent_table.checkout_latest() + except FileNotFoundError as e: + raise DestinationTransientException( + "Couldn't open lancedb database. Orphan removal WILL BE RETRIED" + ) from e + + try: + # Chunks in child table. + if self.fq_parent_table_name: + parent_ids = set( + pc.unique(parent_table.to_arrow()["_dlt_id"]).to_pylist() + ) + child_ids = set( + pc.unique(child_table.to_arrow()["_dlt_parent_id"]).to_pylist() + ) + + if orphaned_ids := child_ids - parent_ids: + if len(orphaned_ids) > 1: + child_table.delete( + "_dlt_parent_id IN" + f" {tuple(orphaned_ids) if len(orphaned_ids) > 1 else orphaned_ids.pop()}" + ) + elif len(orphaned_ids) == 1: + child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") + + # Chunks in root table. TODO: Add test for embeddings in root table + else: + ... + except ArrowInvalid as e: + raise DestinationTerminalException( + "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." 
+ ) from e + + self._state = "completed" + + def state(self) -> TLoadJobState: + return self._state + + def exception(self) -> str: + raise NotImplementedError() From 5b3acb167cec0fa863124191490766fdfe5c4197 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 30 Jul 2024 19:30:16 +0200 Subject: [PATCH 13/68] Implement efficient update strategy for chunked documents in LanceDB Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 57 ++++++++++++++----- 1 file changed, 44 insertions(+), 13 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 4e3d5c7488..a8ea9f6f9d 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -38,7 +38,7 @@ StorageSchemaInfo, StateInfo, TLoadJobState, - FollowupJob, + NewLoadJob, ) from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables @@ -49,7 +49,7 @@ TWriteDisposition, ) from dlt.common.schema.utils import get_columns_names_with_prop -from dlt.common.storages import FileStorage +from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import ( LanceDBClientConfiguration, @@ -70,7 +70,7 @@ generate_uuid, set_non_standard_providers_environment_variables, ) -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.job_impl import EmptyLoadJob, NewReferenceJob from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -756,6 +756,33 @@ def start_file_load( ), ) + def create_table_chain_completed_followup_jobs( + self, + table_chain: Sequence[TTableSchema], + completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, + ) -> List[NewLoadJob]: + assert completed_table_chain_jobs is not None + jobs = super().create_table_chain_completed_followup_jobs( + table_chain, completed_table_chain_jobs + ) + + for table in table_chain: + parent_table = table.get("parent") + jobs.append( + LanceDBRemoveOrphansJob( + db_client=self.db_client, + table_schema=self.prepare_load_table(table["name"]), + fq_table_name=self.make_qualified_table_name(table["name"]), + fq_parent_table_name=( + self.make_qualified_table_name(parent_table) + if parent_table + else None + ), + ) + ) + + return jobs + def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() @@ -829,18 +856,24 @@ def exception(self) -> str: raise NotImplementedError() -class LanceDBRemoveOrphansJob(LoadJob, FollowupJob): +class LanceDBRemoveOrphansJob(NewReferenceJob): def __init__( self, db_client: DBConnection, - file_name: str, table_schema: TTableSchema, fq_table_name: str, fq_parent_table_name: Optional[str], ) -> None: - super().__init__(file_name) self.db_client = db_client - self._state: TLoadJobState = "running" + + ref_file_name = ParsedLoadJobFileName( + table_schema["name"], ParsedLoadJobFileName.new_file_id(), 0, "reference" + ).file_name() + super().__init__( + file_name=ref_file_name, + status="running", + ) + self.table_schema: TTableSchema = table_schema self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name @@ -848,6 +881,8 @@ def __init__( TWriteDisposition, self.table_schema.get("write_disposition", "append") ) + self.execute() + def execute(self) -> None: if self.write_disposition not in ("merge" or "upsert"): raise DestinationTerminalException( @@ -885,6 
+920,7 @@ def execute(self) -> None: child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") # Chunks in root table. TODO: Add test for embeddings in root table + # TODO: Add unit tests with simple data else: ... except ArrowInvalid as e: @@ -892,10 +928,5 @@ def execute(self) -> None: "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." ) from e - self._state = "completed" - def state(self) -> TLoadJobState: - return self._state - - def exception(self) -> str: - raise NotImplementedError() + return "completed" From cf6d86a965d9323da1dfe7b7144cacf640ef2de0 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 30 Jul 2024 21:37:47 +0200 Subject: [PATCH 14/68] Add test for removing orphaned records in LanceDB Signed-off-by: Marcel Coetzee --- .../lancedb/test_remove_orphaned_records.py | 84 +++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 tests/load/lancedb/test_remove_orphaned_records.py diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py new file mode 100644 index 0000000000..9f5e3b288c --- /dev/null +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -0,0 +1,84 @@ +from typing import Iterator, List, Generator + +import pytest + +import dlt +from dlt.common.schema.typing import TLoaderMergeStrategy +from dlt.common.typing import DictStrAny +from dlt.common.utils import uniq_id +from tests.load.utils import ( + drop_active_pipeline_data, +) +from tests.pipeline.utils import ( + assert_load_info, + load_table_counts, + load_tables_to_dicts, +) + + +# Mark all tests as essential, don't remove. +pytestmark = pytest.mark.essential + + +@pytest.fixture(autouse=True) +def drop_lancedb_data() -> Iterator[None]: + yield + drop_active_pipeline_data() + + +@pytest.mark.parametrize("merge_strategy", ("delete-insert", "upsert")) +def test_lancedb_remove_orphaned_records( + merge_strategy: TLoaderMergeStrategy, +) -> None: + pipeline = dlt.pipeline( + pipeline_name="test_pipeline_append", + destination="lancedb", + dataset_name=f"TestPipelineAppendDataset{uniq_id()}", + ) + + @dlt.resource( + table_name="parent", + write_disposition={"disposition": "merge", "strategy": merge_strategy}, + primary_key="id", + ) + def identity_resource( + data: List[DictStrAny], + ) -> Generator[List[DictStrAny], None, None]: + yield data + + run_1 = [ + {"id": 1, "child": [{"bar": 1}, {"bar": 2}]}, + {"id": 2, "child": [{"bar": 3}]}, + {"id": 3, "child": [{"bar": 10}, {"bar": 11}]}, + ] + info = pipeline.run(identity_resource(run_1)) + assert_load_info(info) + + counts = load_table_counts(pipeline, "parent", "parent__child") + assert counts["parent"] == 2 + assert counts["parent__child"] == 4 + + run_2 = [ + {"id": 1, "child": [{"bar": 1}]}, # Removed one child. + {"id": 2, "child": [{"bar": 4}, {"baz": 1}]}, # Changed child. + ] + info = pipeline.run(identity_resource(run_2)) + assert_load_info(info) + + # Check whether orphaned child records were removed. 
+ counts = load_table_counts(pipeline, "parent", "parent__child") + assert counts["parent"] == 2 + assert counts["parent__child"] == 2 + + child_data = load_tables_to_dicts(pipeline, "parent__child") + expected_child_data = [ + {"bar": 1}, + {"bar": 4}, + {"baz": 1}, + {"bar": 10}, + {"bar": 11}, + ] + assert ( + sorted(child_data["parent__child"], key=lambda x: x["bar"]) + == expected_child_data + ) From d338586b115ef6e5fbd1a4578fe6d374b42be714 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 30 Jul 2024 22:41:48 +0200 Subject: [PATCH 15/68] Update LanceDB orphaned records removal test for chunked documents Signed-off-by: Marcel Coetzee --- .../lancedb/test_remove_orphaned_records.py | 36 +++++++------------ 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 9f5e3b288c..5e19efba09 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -11,8 +11,6 @@ ) from tests.pipeline.utils import ( assert_load_info, - load_table_counts, - load_tables_to_dicts, ) @@ -54,31 +52,23 @@ def identity_resource( info = pipeline.run(identity_resource(run_1)) assert_load_info(info) - counts = load_table_counts(pipeline, "parent", "parent__child") - assert counts["parent"] == 2 - assert counts["parent__child"] == 4 - run_2 = [ {"id": 1, "child": [{"bar": 1}]}, # Removed one child. - {"id": 2, "child": [{"bar": 4}, {"baz": 1}]}, # Changed child. + {"id": 2, "child": [{"bar": 4}]}, # Changed child. ] info = pipeline.run(identity_resource(run_2)) assert_load_info(info) - # Check whether orphaned child records were removed. - counts = load_table_counts(pipeline, "parent", "parent__child") - assert counts["parent"] == 2 - assert counts["parent__child"] == 2 + with pipeline.destination_client() as client: + expected_child_data = [ + 1, + 4, + 10, + 11, + ] - child_data = load_tables_to_dicts(pipeline, "parent__child") - expected_child_data = [ - {"bar": 1}, - {"bar": 4}, - {"baz": 1}, - {"bar": 10}, - {"bar": 11}, - ] - assert ( - sorted(child_data["parent__child"], key=lambda x: x["bar"]) - == expected_child_data - ) + embeddings_table_name = client.make_qualified_table_name("parent__child") # type: ignore[attr-defined] + + tbl = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] + df = tbl.to_pandas() + assert sorted(df["bar"].to_list()) == expected_child_data From 2376c6a445b67d8949285bacdd9e4f619e5af463 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 30 Jul 2024 22:49:22 +0200 Subject: [PATCH 16/68] Set test pipeline as dev mode Signed-off-by: Marcel Coetzee --- tests/load/lancedb/test_remove_orphaned_records.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 5e19efba09..f367d31807 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -32,6 +32,7 @@ def test_lancedb_remove_orphaned_records( pipeline_name="test_pipeline_append", destination="lancedb", dataset_name=f"TestPipelineAppendDataset{uniq_id()}", + dev_mode=True, ) @dlt.resource( From 7f6f1cd0e0b3eb9a3cd1b6fcc3d46f9a63e75c4d Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 30 Jul 2024 23:56:16 +0200 Subject: [PATCH 17/68] Fix write disposition check in LanceDBRemoveOrphansJob execute method Signed-off-by: Marcel Coetzee --- 
dlt/destinations/impl/lancedb/lancedb_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index a8ea9f6f9d..a6dad09b60 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -186,7 +186,7 @@ def upload_batch( tbl.add(records) elif write_disposition == "replace": tbl.add(records, mode="overwrite") - elif write_disposition in ("merge" or "upsert"): + elif write_disposition == "merge": if not id_field_name: raise ValueError( "To perform a merge update, 'id_field_name' must be specified." @@ -884,7 +884,7 @@ def __init__( self.execute() def execute(self) -> None: - if self.write_disposition not in ("merge" or "upsert"): + if self.write_disposition != "merge": raise DestinationTerminalException( f"Unsupported write disposition {self.write_disposition} for LanceDB Destination Orphan Removal Job - failed AND WILL **NOT** BE RETRIED." ) From c276211b91cfc3735f0d0a36a83e0715ab4c3313 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Wed, 31 Jul 2024 16:07:14 +0200 Subject: [PATCH 18/68] Add FollowupJob trait to LoadLanceDBJob Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index a6dad09b60..5beff057dd 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -39,6 +39,7 @@ StateInfo, TLoadJobState, NewLoadJob, + FollowupJob, ) from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables @@ -70,7 +71,7 @@ generate_uuid, set_non_standard_providers_environment_variables, ) -from dlt.destinations.job_impl import EmptyLoadJob, NewReferenceJob +from dlt.destinations.job_impl import EmptyLoadJob, NewReferenceJob, NewLoadJobImpl from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -787,7 +788,7 @@ def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() -class LoadLanceDBJob(LoadJob): +class LoadLanceDBJob(LoadJob, FollowupJob): arrow_schema: TArrowSchema def __init__( @@ -856,7 +857,7 @@ def exception(self) -> str: raise NotImplementedError() -class LanceDBRemoveOrphansJob(NewReferenceJob): +class LanceDBRemoveOrphansJob(NewLoadJobImpl): def __init__( self, db_client: DBConnection, @@ -920,7 +921,6 @@ def execute(self) -> None: child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") # Chunks in root table. TODO: Add test for embeddings in root table - # TODO: Add unit tests with simple data else: ... 
except ArrowInvalid as e: From dbfd5af9c28bed805a5364d9fa954d2e8abf6210 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Wed, 31 Jul 2024 17:48:04 +0200 Subject: [PATCH 19/68] Fix file type Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 52 +++++++++++-------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 5beff057dd..cb8935e672 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -71,7 +71,7 @@ generate_uuid, set_non_standard_providers_environment_variables, ) -from dlt.destinations.job_impl import EmptyLoadJob, NewReferenceJob, NewLoadJobImpl +from dlt.destinations.job_impl import EmptyLoadJob, NewLoadJobImpl from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -768,19 +768,21 @@ def create_table_chain_completed_followup_jobs( ) for table in table_chain: - parent_table = table.get("parent") - jobs.append( - LanceDBRemoveOrphansJob( - db_client=self.db_client, - table_schema=self.prepare_load_table(table["name"]), - fq_table_name=self.make_qualified_table_name(table["name"]), - fq_parent_table_name=( - self.make_qualified_table_name(parent_table) - if parent_table - else None - ), + # Only tables with merge disposition are dispatched for orphan removal jobs. + if table.get("write_disposition", "append") == "merge": + parent_table = table.get("parent") + jobs.append( + LanceDBRemoveOrphansJob( + db_client=self.db_client, + table_schema=self.prepare_load_table(table["name"]), + fq_table_name=self.make_qualified_table_name(table["name"]), + fq_parent_table_name=( + self.make_qualified_table_name(parent_table) + if parent_table + else None + ), + ) ) - ) return jobs @@ -866,15 +868,6 @@ def __init__( fq_parent_table_name: Optional[str], ) -> None: self.db_client = db_client - - ref_file_name = ParsedLoadJobFileName( - table_schema["name"], ParsedLoadJobFileName.new_file_id(), 0, "reference" - ).file_name() - super().__init__( - file_name=ref_file_name, - status="running", - ) - self.table_schema: TTableSchema = table_schema self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name @@ -882,6 +875,20 @@ def __init__( TWriteDisposition, self.table_schema.get("write_disposition", "append") ) + job_id = ParsedLoadJobFileName( + table_schema["name"], + ParsedLoadJobFileName.new_file_id(), + 0, + "parquet", + ).file_name() + + super().__init__( + file_name=job_id, + status="running", + ) + + self._save_text_file("") + self.execute() def execute(self) -> None: @@ -890,6 +897,7 @@ def execute(self) -> None: f"Unsupported write disposition {self.write_disposition} for LanceDB Destination Orphan Removal Job - failed AND WILL **NOT** BE RETRIED." ) + # Orphans are removed irrespective of which merge strategy is picked. 
try: child_table = self.db_client.open_table(self.fq_table_name) child_table.checkout_latest() From 257fbde1c809211e4dfce4962f40807b59ad71b6 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Wed, 31 Jul 2024 17:54:55 +0200 Subject: [PATCH 20/68] Fix file typing Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 136 +++++------------- .../lancedb/test_remove_orphaned_records.py | 8 +- 2 files changed, 38 insertions(+), 106 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index cb8935e672..40da46f199 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -80,9 +80,7 @@ NDArray = ndarray TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} -UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = { - v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items() -} +UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} class LanceDBTypeMapper(TypeMapper): @@ -189,9 +187,7 @@ def upload_batch( tbl.add(records, mode="overwrite") elif write_disposition == "merge": if not id_field_name: - raise ValueError( - "To perform a merge update, 'id_field_name' must be specified." - ) + raise ValueError("To perform a merge update, 'id_field_name' must be specified.") tbl.merge_insert( id_field_name ).when_matched_update_all().when_not_matched_insert_all().execute(records) @@ -280,9 +276,7 @@ def get_table_schema(self, table_name: str) -> TArrowSchema: ) @lancedb_error - def create_table( - self, table_name: str, schema: TArrowSchema, mode: str = "create" - ) -> Table: + def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table: """Create a LanceDB Table from the provided LanceModel or PyArrow schema. Args: @@ -305,9 +299,7 @@ def delete_table(self, table_name: str) -> None: def query_table( self, table_name: str, - query: Union[ - List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None - ] = None, + query: Union[List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None] = None, ) -> LanceQueryBuilder: """Query a LanceDB table. @@ -335,11 +327,7 @@ def _get_table_names(self) -> List[str]: else: table_names = self.db_client.table_names() - return [ - table_name - for table_name in table_names - if table_name != self.sentinel_table - ] + return [table_name for table_name in table_names if table_name != self.sentinel_table] @lancedb_error def drop_storage(self) -> None: @@ -392,9 +380,7 @@ def update_stored_schema( applied_update: TSchemaTables = {} try: - schema_info = self.get_stored_schema_by_hash( - self.schema.stored_version_hash - ) + schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash) except DestinationUndefinedEntity: schema_info = None @@ -449,25 +435,19 @@ def add_table_fields( # Check if any of the new fields already exist in the table. existing_fields = set(arrow_table.schema.names) - new_fields = [ - field for field in field_schemas if field.name not in existing_fields - ] + new_fields = [field for field in field_schemas if field.name not in existing_fields] if not new_fields: # All fields already present, skip. 
return None - null_arrays = [ - pa.nulls(len(arrow_table), type=field.type) for field in new_fields - ] + null_arrays = [pa.nulls(len(arrow_table), type=field.type) for field in new_fields] for field, null_array in zip(new_fields, null_arrays): arrow_table = arrow_table.append_column(field, null_array) try: - return self.db_client.create_table( - table_name, arrow_table, mode="overwrite" - ) + return self.db_client.create_table(table_name, arrow_table, mode="overwrite") except OSError: # Error occurred while creating the table, skip. return None @@ -480,15 +460,11 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: existing_columns, self.capabilities.generates_case_sensitive_identifiers(), ) - logger.info( - f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}" - ) + logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") if len(new_columns) > 0: if exists: field_schemas: List[TArrowField] = [ - make_arrow_field_schema( - column["name"], column, self.type_mapper - ) + make_arrow_field_schema(column["name"], column, self.type_mapper) for column in new_columns ] fq_table_name = self.make_qualified_table_name(table_name) @@ -501,9 +477,7 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: vector_field_name = self.vector_field_name id_field_name = self.id_field_name embedding_model_func = self.model_func - embedding_model_dimensions = ( - self.config.embedding_model_dimensions - ) + embedding_model_dimensions = self.config.embedding_model_dimensions else: embedding_fields = None vector_field_name = None @@ -534,12 +508,8 @@ def update_schema_in_storage(self) -> None: self.schema.naming.normalize_identifier( "engine_version" ): self.schema.ENGINE_VERSION, - self.schema.naming.normalize_identifier("inserted_at"): str( - pendulum.now() - ), - self.schema.naming.normalize_identifier( - "schema_name" - ): self.schema.name, + self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), + self.schema.naming.normalize_identifier("schema_name"): self.schema.name, self.schema.naming.normalize_identifier( "version_hash" ): self.schema.stored_version_hash, @@ -548,9 +518,7 @@ def update_schema_in_storage(self) -> None: ), } ] - fq_version_table_name = self.make_qualified_table_name( - self.schema.version_table_name - ) + fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) write_disposition = self.schema.get_table(self.schema.version_table_name).get( "write_disposition" ) @@ -565,12 +533,8 @@ def update_schema_in_storage(self) -> None: @lancedb_error def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: """Retrieves the latest completed state for a pipeline.""" - fq_state_table_name = self.make_qualified_table_name( - self.schema.state_table_name - ) - fq_loads_table_name = self.make_qualified_table_name( - self.schema.loads_table_name - ) + fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name) + fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) state_table_: Table = self.db_client.open_table(fq_state_table_name) state_table_.checkout_latest() @@ -596,9 +560,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: .where(f"`{p_pipeline_name}` = '{pipeline_name}'", prefilter=True) .to_arrow() ) - loads_table = ( - loads_table_.search().where(f"`{p_status}` = 0", prefilter=True).to_arrow() - ) + loads_table = loads_table_.search().where(f"`{p_status}` = 0", 
prefilter=True).to_arrow() # Join arrow tables in-memory. joined_table: pa.Table = state_table.join( @@ -620,12 +582,8 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: ) @lancedb_error - def get_stored_schema_by_hash( - self, schema_hash: str - ) -> Optional[StorageSchemaInfo]: - fq_version_table_name = self.make_qualified_table_name( - self.schema.version_table_name - ) + def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: + fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) version_table: Table = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() @@ -643,9 +601,7 @@ def get_stored_schema_by_hash( ) ).to_list() - most_recent_schema = sorted( - schemas, key=lambda x: x[p_inserted_at], reverse=True - )[0] + most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], @@ -660,9 +616,7 @@ def get_stored_schema_by_hash( @lancedb_error def get_stored_schema(self) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage.""" - fq_version_table_name = self.make_qualified_table_name( - self.schema.version_table_name - ) + fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) version_table: Table = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() @@ -680,9 +634,7 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: ) ).to_list() - most_recent_schema = sorted( - schemas, key=lambda x: x[p_inserted_at], reverse=True - )[0] + most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], @@ -710,21 +662,15 @@ def complete_load(self, load_id: str) -> None: records = [ { self.schema.naming.normalize_identifier("load_id"): load_id, - self.schema.naming.normalize_identifier( - "schema_name" - ): self.schema.name, + self.schema.naming.normalize_identifier("schema_name"): self.schema.name, self.schema.naming.normalize_identifier("status"): 0, - self.schema.naming.normalize_identifier("inserted_at"): str( - pendulum.now() - ), + self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), self.schema.naming.normalize_identifier( "schema_version_hash" ): None, # Payload schema must match the target schema. } ] - fq_loads_table_name = self.make_qualified_table_name( - self.schema.loads_table_name - ) + fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) write_disposition = self.schema.get_table(self.schema.loads_table_name).get( "write_disposition" ) @@ -738,9 +684,7 @@ def complete_load(self, load_id: str) -> None: def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") - def start_file_load( - self, table: TTableSchema, file_path: str, load_id: str - ) -> LoadJob: + def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: parent_table = table.get("parent") return LoadLanceDBJob( @@ -769,7 +713,7 @@ def create_table_chain_completed_followup_jobs( for table in table_chain: # Only tables with merge disposition are dispatched for orphan removal jobs. 
- if table.get("write_disposition", "append") == "merge": + if table.get("write_disposition") == "merge": parent_table = table.get("parent") jobs.append( LanceDBRemoveOrphansJob( @@ -777,9 +721,7 @@ def create_table_chain_completed_followup_jobs( table_schema=self.prepare_load_table(table["name"]), fq_table_name=self.make_qualified_table_name(table["name"]), fq_parent_table_name=( - self.make_qualified_table_name(parent_table) - if parent_table - else None + self.make_qualified_table_name(parent_table) if parent_table else None ), ) ) @@ -815,9 +757,7 @@ def __init__( self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name self.unique_identifiers: Sequence[str] = list_merge_identifiers(table_schema) - self.embedding_fields: List[str] = get_columns_names_with_prop( - table_schema, VECTORIZE_HINT - ) + self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) self.embedding_model_func: TextEmbeddingFunction = model_func self.embedding_model_dimensions: int = client_config.embedding_model_dimensions self.id_field_name: str = client_config.id_field_name @@ -872,7 +812,7 @@ def __init__( self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name self.write_disposition: TWriteDisposition = cast( - TWriteDisposition, self.table_schema.get("write_disposition", "append") + TWriteDisposition, self.table_schema.get("write_disposition") ) job_id = ParsedLoadJobFileName( @@ -894,7 +834,8 @@ def __init__( def execute(self) -> None: if self.write_disposition != "merge": raise DestinationTerminalException( - f"Unsupported write disposition {self.write_disposition} for LanceDB Destination Orphan Removal Job - failed AND WILL **NOT** BE RETRIED." + f"Unsupported write disposition {self.write_disposition} for LanceDB Destination" + " Orphan Removal Job - failed AND WILL **NOT** BE RETRIED." ) # Orphans are removed irrespective of which merge strategy is picked. @@ -912,12 +853,8 @@ def execute(self) -> None: try: # Chunks in child table. if self.fq_parent_table_name: - parent_ids = set( - pc.unique(parent_table.to_arrow()["_dlt_id"]).to_pylist() - ) - child_ids = set( - pc.unique(child_table.to_arrow()["_dlt_parent_id"]).to_pylist() - ) + parent_ids = set(pc.unique(parent_table.to_arrow()["_dlt_id"]).to_pylist()) + child_ids = set(pc.unique(child_table.to_arrow()["_dlt_parent_id"]).to_pylist()) if orphaned_ids := child_ids - parent_ids: if len(orphaned_ids) > 1: @@ -929,8 +866,7 @@ def execute(self) -> None: child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") # Chunks in root table. TODO: Add test for embeddings in root table - else: - ... + else: ... except ArrowInvalid as e: raise DestinationTerminalException( "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." 
diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index f367d31807..c5b60a5f3f 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -3,7 +3,6 @@ import pytest import dlt -from dlt.common.schema.typing import TLoaderMergeStrategy from dlt.common.typing import DictStrAny from dlt.common.utils import uniq_id from tests.load.utils import ( @@ -24,10 +23,7 @@ def drop_lancedb_data() -> Iterator[None]: drop_active_pipeline_data() -@pytest.mark.parametrize("merge_strategy", ("delete-insert", "upsert")) -def test_lancedb_remove_orphaned_records( - merge_strategy: TLoaderMergeStrategy, -) -> None: +def test_lancedb_remove_orphaned_records() -> None: pipeline = dlt.pipeline( pipeline_name="test_pipeline_append", destination="lancedb", @@ -37,7 +33,7 @@ def test_lancedb_remove_orphaned_records( @dlt.resource( table_name="parent", - write_disposition={"disposition": "merge", "strategy": merge_strategy}, + write_disposition="merge", primary_key="id", ) def identity_resource( From 0502ddf01c7cf3b190f9352b1f0fbe8d2e59f13d Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Wed, 31 Jul 2024 20:25:55 +0200 Subject: [PATCH 21/68] Add test for removing orphaned records in LanceDB root table Signed-off-by: Marcel Coetzee --- .../lancedb/test_remove_orphaned_records.py | 89 ++++++++++++++++--- 1 file changed, 79 insertions(+), 10 deletions(-) diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index c5b60a5f3f..01481d83e3 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -1,6 +1,9 @@ from typing import Iterator, List, Generator +import pandas as pd import pytest +from pandas import DataFrame +from pandas.testing import assert_frame_equal import dlt from dlt.common.typing import DictStrAny @@ -25,9 +28,9 @@ def drop_lancedb_data() -> Iterator[None]: def test_lancedb_remove_orphaned_records() -> None: pipeline = dlt.pipeline( - pipeline_name="test_pipeline_append", + pipeline_name="test_lancedb_remove_orphaned_records", destination="lancedb", - dataset_name=f"TestPipelineAppendDataset{uniq_id()}", + dataset_name=f"test_lancedb_remove_orphaned_records_{uniq_id()}", dev_mode=True, ) @@ -57,15 +60,81 @@ def identity_resource( assert_load_info(info) with pipeline.destination_client() as client: - expected_child_data = [ - 1, - 4, - 10, - 11, - ] + expected_child_data = pd.DataFrame( + data=[ + {"bar": 1}, + {"bar": 4}, + {"bar": 10}, + {"bar": 11}, + ] + ) embeddings_table_name = client.make_qualified_table_name("parent__child") # type: ignore[attr-defined] tbl = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] - df = tbl.to_pandas() - assert sorted(df["bar"].to_list()) == expected_child_data + actual_df = tbl.to_pandas() + + expected_child_data = expected_child_data.sort_values(by="bar") + actual_df = actual_df.sort_values(by="bar").reset_index(drop=True) + + assert_frame_equal(actual_df[["bar"]], expected_child_data) + + +def test_lancedb_remove_orphaned_records_root_table() -> None: + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_remove_orphaned_records_root_table", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_root_table_{uniq_id()}", + dev_mode=True, + ) + + @dlt.resource( + table_name="root", + write_disposition="merge", + primary_key="doc_id", + ) + def 
identity_resource( + data: List[DictStrAny], + ) -> Generator[List[DictStrAny], None, None]: + yield data + + run_1 = [ + {"doc_id": 1, "chunk_hash": "1a"}, + {"doc_id": 2, "chunk_hash": "2a"}, + {"doc_id": 2, "chunk_hash": "2b"}, + {"doc_id": 2, "chunk_hash": "2c"}, + {"doc_id": 3, "chunk_hash": "3a"}, + {"doc_id": 3, "chunk_hash": "3b"}, + ] + info = pipeline.run(identity_resource(run_1)) + assert_load_info(info) + + run_2 = [ + {"doc_id": 2, "chunk_hash": "2d"}, + {"doc_id": 2, "chunk_hash": "2e"}, + {"doc_id": 3, "chunk_hash": "3b"}, + ] + info = pipeline.run(identity_resource(run_2)) + assert_load_info(info) + + with pipeline.destination_client() as client: + expected_root_table_df = pd.DataFrame( + data=[ + {"doc_id": 1, "chunk_hash": "1a"}, + {"doc_id": 2, "chunk_hash": "2d"}, + {"doc_id": 2, "chunk_hash": "2e"}, + {"doc_id": 3, "chunk_hash": "3b"}, + ] + ).sort_values(by=["doc_id", "chunk_hash"]) + + root_table_name = ( + client.make_qualified_table_name("root") # type: ignore[attr-defined] + ) + tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] + + actual_root_df: DataFrame = ( + tbl.to_pandas() + .sort_values(by=["doc_id", "chunk_hash"]) + .reset_index(drop=True) + ) + assert_frame_equal(actual_root_df, expected_root_table_df) From 2363b51a38c085f6e9f81b1002a2e79021cb0724 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Wed, 31 Jul 2024 20:48:16 +0200 Subject: [PATCH 22/68] Enhance LanceDB test to cover nested child removal and update scenarios Signed-off-by: Marcel Coetzee --- .../lancedb/test_remove_orphaned_records.py | 93 +++++++++++++++---- 1 file changed, 74 insertions(+), 19 deletions(-) diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 01481d83e3..9104399c9b 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -45,16 +45,34 @@ def identity_resource( yield data run_1 = [ - {"id": 1, "child": [{"bar": 1}, {"bar": 2}]}, - {"id": 2, "child": [{"bar": 3}]}, - {"id": 3, "child": [{"bar": 10}, {"bar": 11}]}, + { + "id": 1, + "child": [ + {"bar": 1, "grandchild": [{"baz": 1}, {"baz": 2}]}, + {"bar": 2, "grandchild": [{"baz": 3}]}, + ], + }, + {"id": 2, "child": [{"bar": 3, "grandchild": [{"baz": 4}]}]}, + { + "id": 3, + "child": [ + {"bar": 10, "grandchild": [{"baz": 5}]}, + {"bar": 11, "grandchild": [{"baz": 6}, {"baz": 7}]}, + ], + }, ] info = pipeline.run(identity_resource(run_1)) assert_load_info(info) run_2 = [ - {"id": 1, "child": [{"bar": 1}]}, # Removed one child. - {"id": 2, "child": [{"bar": 4}]}, # Changed child. 
+ { + "id": 1, + "child": [{"bar": 1, "grandchild": [{"baz": 1}]}], + }, # Removed one child and one grandchild + { + "id": 2, + "child": [{"bar": 4, "grandchild": [{"baz": 8}]}], + }, # Changed child and grandchild ] info = pipeline.run(identity_resource(run_2)) assert_load_info(info) @@ -69,15 +87,48 @@ def identity_resource( ] ) - embeddings_table_name = client.make_qualified_table_name("parent__child") # type: ignore[attr-defined] + expected_grandchild_data = pd.DataFrame( + data=[ + {"baz": 1}, + {"baz": 8}, + {"baz": 5}, + {"baz": 6}, + {"baz": 7}, + ] + ) + + child_table_name = client.make_qualified_table_name("parent__child") # type: ignore[attr-defined] + grandchild_table_name = client.make_qualified_table_name( # type: ignore[attr-defined] + "parent__child__grandchild" + ) - tbl = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] - actual_df = tbl.to_pandas() + child_tbl = client.db_client.open_table(child_table_name) # type: ignore[attr-defined] + grandchild_tbl = client.db_client.open_table(grandchild_table_name) # type: ignore[attr-defined] + + actual_child_df = ( + child_tbl.to_pandas() + .sort_values(by="bar") + .reset_index(drop=True) + .reset_index(drop=True) + ) + actual_grandchild_df = ( + grandchild_tbl.to_pandas() + .sort_values(by="baz") + .reset_index(drop=True) + .reset_index(drop=True) + ) - expected_child_data = expected_child_data.sort_values(by="bar") - actual_df = actual_df.sort_values(by="bar").reset_index(drop=True) + expected_child_data = expected_child_data.sort_values(by="bar").reset_index( + drop=True + ) + expected_grandchild_data = expected_grandchild_data.sort_values( + by="baz" + ).reset_index(drop=True) - assert_frame_equal(actual_df[["bar"]], expected_child_data) + assert_frame_equal(actual_child_df[["bar"]], expected_child_data) + print(actual_grandchild_df[["baz"]]) + print(expected_grandchild_data) + assert_frame_equal(actual_grandchild_df[["baz"]], expected_grandchild_data) def test_lancedb_remove_orphaned_records_root_table() -> None: @@ -118,14 +169,18 @@ def identity_resource( assert_load_info(info) with pipeline.destination_client() as client: - expected_root_table_df = pd.DataFrame( - data=[ - {"doc_id": 1, "chunk_hash": "1a"}, - {"doc_id": 2, "chunk_hash": "2d"}, - {"doc_id": 2, "chunk_hash": "2e"}, - {"doc_id": 3, "chunk_hash": "3b"}, - ] - ).sort_values(by=["doc_id", "chunk_hash"]) + expected_root_table_df = ( + pd.DataFrame( + data=[ + {"doc_id": 1, "chunk_hash": "1a"}, + {"doc_id": 2, "chunk_hash": "2d"}, + {"doc_id": 2, "chunk_hash": "2e"}, + {"doc_id": 3, "chunk_hash": "3b"}, + ] + ) + .sort_values(by=["doc_id", "chunk_hash"]) + .reset_index(drop=True) + ) root_table_name = ( client.make_qualified_table_name("root") # type: ignore[attr-defined] From 6b363d1b6601d795e71a8a29b8501c4958cf3d69 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 1 Aug 2024 22:20:53 +0200 Subject: [PATCH 23/68] Use doc id hint for top level tables Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 178 ++++++++++++++---- dlt/destinations/impl/lancedb/utils.py | 4 +- .../lancedb/test_remove_orphaned_records.py | 15 +- 3 files changed, 149 insertions(+), 48 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 40da46f199..d2cc0982fb 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -12,6 +12,7 @@ Dict, Sequence, TYPE_CHECKING, + Set, ) import lancedb # type: 
ignore @@ -22,7 +23,7 @@ from lancedb.query import LanceQueryBuilder # type: ignore from lancedb.table import Table # type: ignore from numpy import ndarray -from pyarrow import Array, ChunkedArray, ArrowInvalid +from pyarrow import Array, ChunkedArray, ArrowInvalid, Table from dlt.common import json, pendulum, logger from dlt.common.destination import DestinationCapabilitiesContext @@ -58,7 +59,10 @@ from dlt.destinations.impl.lancedb.exceptions import ( lancedb_error, ) -from dlt.destinations.impl.lancedb.lancedb_adapter import VECTORIZE_HINT +from dlt.destinations.impl.lancedb.lancedb_adapter import ( + VECTORIZE_HINT, + DOCUMENT_ID_HINT, +) from dlt.destinations.impl.lancedb.schema import ( make_arrow_field_schema, make_arrow_table_schema, @@ -80,7 +84,9 @@ NDArray = ndarray TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} -UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} +UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = { + v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items() +} class LanceDBTypeMapper(TypeMapper): @@ -187,7 +193,9 @@ def upload_batch( tbl.add(records, mode="overwrite") elif write_disposition == "merge": if not id_field_name: - raise ValueError("To perform a merge update, 'id_field_name' must be specified.") + raise ValueError( + "To perform a merge update, 'id_field_name' must be specified." + ) tbl.merge_insert( id_field_name ).when_matched_update_all().when_not_matched_insert_all().execute(records) @@ -276,7 +284,9 @@ def get_table_schema(self, table_name: str) -> TArrowSchema: ) @lancedb_error - def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table: + def create_table( + self, table_name: str, schema: TArrowSchema, mode: str = "create" + ) -> Table: """Create a LanceDB Table from the provided LanceModel or PyArrow schema. Args: @@ -299,7 +309,9 @@ def delete_table(self, table_name: str) -> None: def query_table( self, table_name: str, - query: Union[List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None] = None, + query: Union[ + List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None + ] = None, ) -> LanceQueryBuilder: """Query a LanceDB table. @@ -327,7 +339,11 @@ def _get_table_names(self) -> List[str]: else: table_names = self.db_client.table_names() - return [table_name for table_name in table_names if table_name != self.sentinel_table] + return [ + table_name + for table_name in table_names + if table_name != self.sentinel_table + ] @lancedb_error def drop_storage(self) -> None: @@ -380,7 +396,9 @@ def update_stored_schema( applied_update: TSchemaTables = {} try: - schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash) + schema_info = self.get_stored_schema_by_hash( + self.schema.stored_version_hash + ) except DestinationUndefinedEntity: schema_info = None @@ -435,19 +453,25 @@ def add_table_fields( # Check if any of the new fields already exist in the table. existing_fields = set(arrow_table.schema.names) - new_fields = [field for field in field_schemas if field.name not in existing_fields] + new_fields = [ + field for field in field_schemas if field.name not in existing_fields + ] if not new_fields: # All fields already present, skip. 
return None - null_arrays = [pa.nulls(len(arrow_table), type=field.type) for field in new_fields] + null_arrays = [ + pa.nulls(len(arrow_table), type=field.type) for field in new_fields + ] for field, null_array in zip(new_fields, null_arrays): arrow_table = arrow_table.append_column(field, null_array) try: - return self.db_client.create_table(table_name, arrow_table, mode="overwrite") + return self.db_client.create_table( + table_name, arrow_table, mode="overwrite" + ) except OSError: # Error occurred while creating the table, skip. return None @@ -460,11 +484,15 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: existing_columns, self.capabilities.generates_case_sensitive_identifiers(), ) - logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") + logger.info( + f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}" + ) if len(new_columns) > 0: if exists: field_schemas: List[TArrowField] = [ - make_arrow_field_schema(column["name"], column, self.type_mapper) + make_arrow_field_schema( + column["name"], column, self.type_mapper + ) for column in new_columns ] fq_table_name = self.make_qualified_table_name(table_name) @@ -477,7 +505,9 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: vector_field_name = self.vector_field_name id_field_name = self.id_field_name embedding_model_func = self.model_func - embedding_model_dimensions = self.config.embedding_model_dimensions + embedding_model_dimensions = ( + self.config.embedding_model_dimensions + ) else: embedding_fields = None vector_field_name = None @@ -508,8 +538,12 @@ def update_schema_in_storage(self) -> None: self.schema.naming.normalize_identifier( "engine_version" ): self.schema.ENGINE_VERSION, - self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), - self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier("inserted_at"): str( + pendulum.now() + ), + self.schema.naming.normalize_identifier( + "schema_name" + ): self.schema.name, self.schema.naming.normalize_identifier( "version_hash" ): self.schema.stored_version_hash, @@ -518,7 +552,9 @@ def update_schema_in_storage(self) -> None: ), } ] - fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) + fq_version_table_name = self.make_qualified_table_name( + self.schema.version_table_name + ) write_disposition = self.schema.get_table(self.schema.version_table_name).get( "write_disposition" ) @@ -533,8 +569,12 @@ def update_schema_in_storage(self) -> None: @lancedb_error def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: """Retrieves the latest completed state for a pipeline.""" - fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name) - fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) + fq_state_table_name = self.make_qualified_table_name( + self.schema.state_table_name + ) + fq_loads_table_name = self.make_qualified_table_name( + self.schema.loads_table_name + ) state_table_: Table = self.db_client.open_table(fq_state_table_name) state_table_.checkout_latest() @@ -560,7 +600,9 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: .where(f"`{p_pipeline_name}` = '{pipeline_name}'", prefilter=True) .to_arrow() ) - loads_table = loads_table_.search().where(f"`{p_status}` = 0", prefilter=True).to_arrow() + loads_table = ( + loads_table_.search().where(f"`{p_status}` = 0", 
prefilter=True).to_arrow() + ) # Join arrow tables in-memory. joined_table: pa.Table = state_table.join( @@ -582,8 +624,12 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: ) @lancedb_error - def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: - fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) + def get_stored_schema_by_hash( + self, schema_hash: str + ) -> Optional[StorageSchemaInfo]: + fq_version_table_name = self.make_qualified_table_name( + self.schema.version_table_name + ) version_table: Table = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() @@ -601,7 +647,9 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI ) ).to_list() - most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] + most_recent_schema = sorted( + schemas, key=lambda x: x[p_inserted_at], reverse=True + )[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], @@ -616,7 +664,9 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI @lancedb_error def get_stored_schema(self) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage.""" - fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) + fq_version_table_name = self.make_qualified_table_name( + self.schema.version_table_name + ) version_table: Table = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() @@ -634,7 +684,9 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: ) ).to_list() - most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] + most_recent_schema = sorted( + schemas, key=lambda x: x[p_inserted_at], reverse=True + )[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], @@ -662,15 +714,21 @@ def complete_load(self, load_id: str) -> None: records = [ { self.schema.naming.normalize_identifier("load_id"): load_id, - self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier( + "schema_name" + ): self.schema.name, self.schema.naming.normalize_identifier("status"): 0, - self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), + self.schema.naming.normalize_identifier("inserted_at"): str( + pendulum.now() + ), self.schema.naming.normalize_identifier( "schema_version_hash" ): None, # Payload schema must match the target schema. 
} ] - fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) + fq_loads_table_name = self.make_qualified_table_name( + self.schema.loads_table_name + ) write_disposition = self.schema.get_table(self.schema.loads_table_name).get( "write_disposition" ) @@ -684,7 +742,9 @@ def complete_load(self, load_id: str) -> None: def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def start_file_load( + self, table: TTableSchema, file_path: str, load_id: str + ) -> LoadJob: parent_table = table.get("parent") return LoadLanceDBJob( @@ -712,6 +772,9 @@ def create_table_chain_completed_followup_jobs( ) for table in table_chain: + if table in self.schema.dlt_tables(): + continue + # Only tables with merge disposition are dispatched for orphan removal jobs. if table.get("write_disposition") == "merge": parent_table = table.get("parent") @@ -721,8 +784,11 @@ def create_table_chain_completed_followup_jobs( table_schema=self.prepare_load_table(table["name"]), fq_table_name=self.make_qualified_table_name(table["name"]), fq_parent_table_name=( - self.make_qualified_table_name(parent_table) if parent_table else None + self.make_qualified_table_name(parent_table) + if parent_table + else None ), + client_config=self.config, ) ) @@ -756,8 +822,10 @@ def __init__( self.table_name: str = table_schema["name"] self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name - self.unique_identifiers: Sequence[str] = list_merge_identifiers(table_schema) - self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) + self.unique_identifiers: List[str] = list_merge_identifiers(table_schema) + self.embedding_fields: List[str] = get_columns_names_with_prop( + table_schema, VECTORIZE_HINT + ) self.embedding_model_func: TextEmbeddingFunction = model_func self.embedding_model_dimensions: int = client_config.embedding_model_dimensions self.id_field_name: str = client_config.id_field_name @@ -804,6 +872,7 @@ def __init__( self, db_client: DBConnection, table_schema: TTableSchema, + client_config: LanceDBClientConfiguration, fq_table_name: str, fq_parent_table_name: Optional[str], ) -> None: @@ -814,6 +883,7 @@ def __init__( self.write_disposition: TWriteDisposition = cast( TWriteDisposition, self.table_schema.get("write_disposition") ) + self.id_field_name: str = client_config.id_field_name job_id = ParsedLoadJobFileName( table_schema["name"], @@ -832,6 +902,8 @@ def __init__( self.execute() def execute(self) -> None: + orphaned_ids: Set[str] + if self.write_disposition != "merge": raise DestinationTerminalException( f"Unsupported write disposition {self.write_disposition} for LanceDB Destination" @@ -851,22 +923,50 @@ def execute(self) -> None: ) from e try: - # Chunks in child table. if self.fq_parent_table_name: - parent_ids = set(pc.unique(parent_table.to_arrow()["_dlt_id"]).to_pylist()) - child_ids = set(pc.unique(child_table.to_arrow()["_dlt_parent_id"]).to_pylist()) + # Chunks and embeddings in child table. 
+ parent_ids = set( + pc.unique(parent_table.to_arrow()["_dlt_id"]).to_pylist() + ) + child_ids = set( + pc.unique(child_table.to_arrow()["_dlt_parent_id"]).to_pylist() + ) if orphaned_ids := child_ids - parent_ids: if len(orphaned_ids) > 1: child_table.delete( - "_dlt_parent_id IN" - f" {tuple(orphaned_ids) if len(orphaned_ids) > 1 else orphaned_ids.pop()}" + "_dlt_parent_id IN" f" {tuple(orphaned_ids)}" ) elif len(orphaned_ids) == 1: child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") - # Chunks in root table. TODO: Add test for embeddings in root table - else: ... + else: + # Chunks and embeddings in the root table. + child_table_arrow: pa.Table = child_table.to_arrow() + + # If document ID is defined, we use this as the sole grouping key to identify stale chunks, + # else fallback to the compound `id_field_name`. + grouping_key = ( + get_columns_names_with_prop(self.table_schema, DOCUMENT_ID_HINT) + or self.id_field_name + ) + + grouped = child_table_arrow.group_by(grouping_key).aggregate( + [("_dlt_load_id", "max")] + ) + joined = child_table_arrow.join(grouped, keys=grouping_key) + orphaned_mask = pc.not_equal( + joined["_dlt_load_id"], joined["_dlt_load_id_max"] + ) + orphaned_ids = ( + joined.filter(orphaned_mask).column("_dlt_id").to_pylist() + ) + + if len(orphaned_ids) > 1: + child_table.delete("_dlt_id IN" f" {tuple(orphaned_ids)}") + elif len(orphaned_ids) == 1: + child_table.delete(f"_dlt_id = '{orphaned_ids.pop()}'") + except ArrowInvalid as e: raise DestinationTerminalException( "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 874852512a..1d944ff7a4 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -1,6 +1,6 @@ import os import uuid -from typing import Sequence, Union, Dict +from typing import Sequence, Union, Dict, List from dlt.common.schema import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop @@ -32,7 +32,7 @@ def generate_uuid(data: DictStrAny, unique_identifiers: Sequence[str], table_nam return str(uuid.uuid5(uuid.NAMESPACE_DNS, table_name + data_id)) -def list_merge_identifiers(table_schema: TTableSchema) -> Sequence[str]: +def list_merge_identifiers(table_schema: TTableSchema) -> List[str]: """Returns a list of merge keys for a table used for either merging or deduplication. 
Args: diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 9104399c9b..4679238010 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -8,6 +8,7 @@ import dlt from dlt.common.typing import DictStrAny from dlt.common.utils import uniq_id +from dlt.destinations.impl.lancedb.lancedb_adapter import DOCUMENT_ID_HINT from tests.load.utils import ( drop_active_pipeline_data, ) @@ -34,10 +35,10 @@ def test_lancedb_remove_orphaned_records() -> None: dev_mode=True, ) - @dlt.resource( + @dlt.resource( # type: ignore[call-overload] table_name="parent", write_disposition="merge", - primary_key="id", + columns={"id": {DOCUMENT_ID_HINT: True}}, ) def identity_resource( data: List[DictStrAny], @@ -126,8 +127,6 @@ def identity_resource( ).reset_index(drop=True) assert_frame_equal(actual_child_df[["bar"]], expected_child_data) - print(actual_grandchild_df[["baz"]]) - print(expected_grandchild_data) assert_frame_equal(actual_grandchild_df[["baz"]], expected_grandchild_data) @@ -139,10 +138,11 @@ def test_lancedb_remove_orphaned_records_root_table() -> None: dev_mode=True, ) - @dlt.resource( + @dlt.resource( # type: ignore[call-overload] table_name="root", write_disposition="merge", - primary_key="doc_id", + merge_key=["chunk_hash"], + columns={"doc_id": {DOCUMENT_ID_HINT: True}}, ) def identity_resource( data: List[DictStrAny], @@ -191,5 +191,6 @@ def identity_resource( tbl.to_pandas() .sort_values(by=["doc_id", "chunk_hash"]) .reset_index(drop=True) - ) + )[["doc_id", "chunk_hash"]] + assert_frame_equal(actual_root_df, expected_root_table_df) From aac7647ed7ac28cfd67f768ce044d34d452dd181 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 1 Aug 2024 23:03:55 +0200 Subject: [PATCH 24/68] Only join on join columns for orphan removal job Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 145 +++++------------- dlt/destinations/impl/lancedb/utils.py | 4 +- .../lancedb/test_remove_orphaned_records.py | 14 +- 3 files changed, 44 insertions(+), 119 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index d2cc0982fb..cc35aadf18 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -84,9 +84,7 @@ NDArray = ndarray TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} -UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = { - v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items() -} +UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} class LanceDBTypeMapper(TypeMapper): @@ -193,9 +191,7 @@ def upload_batch( tbl.add(records, mode="overwrite") elif write_disposition == "merge": if not id_field_name: - raise ValueError( - "To perform a merge update, 'id_field_name' must be specified." - ) + raise ValueError("To perform a merge update, 'id_field_name' must be specified.") tbl.merge_insert( id_field_name ).when_matched_update_all().when_not_matched_insert_all().execute(records) @@ -284,9 +280,7 @@ def get_table_schema(self, table_name: str) -> TArrowSchema: ) @lancedb_error - def create_table( - self, table_name: str, schema: TArrowSchema, mode: str = "create" - ) -> Table: + def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table: """Create a LanceDB Table from the provided LanceModel or PyArrow schema. 
Args: @@ -309,9 +303,7 @@ def delete_table(self, table_name: str) -> None: def query_table( self, table_name: str, - query: Union[ - List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None - ] = None, + query: Union[List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None] = None, ) -> LanceQueryBuilder: """Query a LanceDB table. @@ -339,11 +331,7 @@ def _get_table_names(self) -> List[str]: else: table_names = self.db_client.table_names() - return [ - table_name - for table_name in table_names - if table_name != self.sentinel_table - ] + return [table_name for table_name in table_names if table_name != self.sentinel_table] @lancedb_error def drop_storage(self) -> None: @@ -396,9 +384,7 @@ def update_stored_schema( applied_update: TSchemaTables = {} try: - schema_info = self.get_stored_schema_by_hash( - self.schema.stored_version_hash - ) + schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash) except DestinationUndefinedEntity: schema_info = None @@ -453,25 +439,19 @@ def add_table_fields( # Check if any of the new fields already exist in the table. existing_fields = set(arrow_table.schema.names) - new_fields = [ - field for field in field_schemas if field.name not in existing_fields - ] + new_fields = [field for field in field_schemas if field.name not in existing_fields] if not new_fields: # All fields already present, skip. return None - null_arrays = [ - pa.nulls(len(arrow_table), type=field.type) for field in new_fields - ] + null_arrays = [pa.nulls(len(arrow_table), type=field.type) for field in new_fields] for field, null_array in zip(new_fields, null_arrays): arrow_table = arrow_table.append_column(field, null_array) try: - return self.db_client.create_table( - table_name, arrow_table, mode="overwrite" - ) + return self.db_client.create_table(table_name, arrow_table, mode="overwrite") except OSError: # Error occurred while creating the table, skip. 
return None @@ -484,15 +464,11 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: existing_columns, self.capabilities.generates_case_sensitive_identifiers(), ) - logger.info( - f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}" - ) + logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") if len(new_columns) > 0: if exists: field_schemas: List[TArrowField] = [ - make_arrow_field_schema( - column["name"], column, self.type_mapper - ) + make_arrow_field_schema(column["name"], column, self.type_mapper) for column in new_columns ] fq_table_name = self.make_qualified_table_name(table_name) @@ -505,9 +481,7 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: vector_field_name = self.vector_field_name id_field_name = self.id_field_name embedding_model_func = self.model_func - embedding_model_dimensions = ( - self.config.embedding_model_dimensions - ) + embedding_model_dimensions = self.config.embedding_model_dimensions else: embedding_fields = None vector_field_name = None @@ -538,12 +512,8 @@ def update_schema_in_storage(self) -> None: self.schema.naming.normalize_identifier( "engine_version" ): self.schema.ENGINE_VERSION, - self.schema.naming.normalize_identifier("inserted_at"): str( - pendulum.now() - ), - self.schema.naming.normalize_identifier( - "schema_name" - ): self.schema.name, + self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), + self.schema.naming.normalize_identifier("schema_name"): self.schema.name, self.schema.naming.normalize_identifier( "version_hash" ): self.schema.stored_version_hash, @@ -552,9 +522,7 @@ def update_schema_in_storage(self) -> None: ), } ] - fq_version_table_name = self.make_qualified_table_name( - self.schema.version_table_name - ) + fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) write_disposition = self.schema.get_table(self.schema.version_table_name).get( "write_disposition" ) @@ -569,12 +537,8 @@ def update_schema_in_storage(self) -> None: @lancedb_error def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: """Retrieves the latest completed state for a pipeline.""" - fq_state_table_name = self.make_qualified_table_name( - self.schema.state_table_name - ) - fq_loads_table_name = self.make_qualified_table_name( - self.schema.loads_table_name - ) + fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name) + fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) state_table_: Table = self.db_client.open_table(fq_state_table_name) state_table_.checkout_latest() @@ -600,9 +564,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: .where(f"`{p_pipeline_name}` = '{pipeline_name}'", prefilter=True) .to_arrow() ) - loads_table = ( - loads_table_.search().where(f"`{p_status}` = 0", prefilter=True).to_arrow() - ) + loads_table = loads_table_.search().where(f"`{p_status}` = 0", prefilter=True).to_arrow() # Join arrow tables in-memory. 
joined_table: pa.Table = state_table.join( @@ -624,12 +586,8 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: ) @lancedb_error - def get_stored_schema_by_hash( - self, schema_hash: str - ) -> Optional[StorageSchemaInfo]: - fq_version_table_name = self.make_qualified_table_name( - self.schema.version_table_name - ) + def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: + fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) version_table: Table = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() @@ -647,9 +605,7 @@ def get_stored_schema_by_hash( ) ).to_list() - most_recent_schema = sorted( - schemas, key=lambda x: x[p_inserted_at], reverse=True - )[0] + most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], @@ -664,9 +620,7 @@ def get_stored_schema_by_hash( @lancedb_error def get_stored_schema(self) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage.""" - fq_version_table_name = self.make_qualified_table_name( - self.schema.version_table_name - ) + fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) version_table: Table = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() @@ -684,9 +638,7 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: ) ).to_list() - most_recent_schema = sorted( - schemas, key=lambda x: x[p_inserted_at], reverse=True - )[0] + most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] return StorageSchemaInfo( version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], @@ -714,21 +666,15 @@ def complete_load(self, load_id: str) -> None: records = [ { self.schema.naming.normalize_identifier("load_id"): load_id, - self.schema.naming.normalize_identifier( - "schema_name" - ): self.schema.name, + self.schema.naming.normalize_identifier("schema_name"): self.schema.name, self.schema.naming.normalize_identifier("status"): 0, - self.schema.naming.normalize_identifier("inserted_at"): str( - pendulum.now() - ), + self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), self.schema.naming.normalize_identifier( "schema_version_hash" ): None, # Payload schema must match the target schema. 
} ] - fq_loads_table_name = self.make_qualified_table_name( - self.schema.loads_table_name - ) + fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) write_disposition = self.schema.get_table(self.schema.loads_table_name).get( "write_disposition" ) @@ -742,9 +688,7 @@ def complete_load(self, load_id: str) -> None: def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") - def start_file_load( - self, table: TTableSchema, file_path: str, load_id: str - ) -> LoadJob: + def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: parent_table = table.get("parent") return LoadLanceDBJob( @@ -784,9 +728,7 @@ def create_table_chain_completed_followup_jobs( table_schema=self.prepare_load_table(table["name"]), fq_table_name=self.make_qualified_table_name(table["name"]), fq_parent_table_name=( - self.make_qualified_table_name(parent_table) - if parent_table - else None + self.make_qualified_table_name(parent_table) if parent_table else None ), client_config=self.config, ) @@ -823,9 +765,7 @@ def __init__( self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name self.unique_identifiers: List[str] = list_merge_identifiers(table_schema) - self.embedding_fields: List[str] = get_columns_names_with_prop( - table_schema, VECTORIZE_HINT - ) + self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) self.embedding_model_func: TextEmbeddingFunction = model_func self.embedding_model_dimensions: int = client_config.embedding_model_dimensions self.id_field_name: str = client_config.id_field_name @@ -925,24 +865,17 @@ def execute(self) -> None: try: if self.fq_parent_table_name: # Chunks and embeddings in child table. - parent_ids = set( - pc.unique(parent_table.to_arrow()["_dlt_id"]).to_pylist() - ) - child_ids = set( - pc.unique(child_table.to_arrow()["_dlt_parent_id"]).to_pylist() - ) + parent_ids = set(pc.unique(parent_table.to_arrow()["_dlt_id"]).to_pylist()) + child_ids = set(pc.unique(child_table.to_arrow()["_dlt_parent_id"]).to_pylist()) if orphaned_ids := child_ids - parent_ids: if len(orphaned_ids) > 1: - child_table.delete( - "_dlt_parent_id IN" f" {tuple(orphaned_ids)}" - ) + child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") elif len(orphaned_ids) == 1: child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") else: # Chunks and embeddings in the root table. - child_table_arrow: pa.Table = child_table.to_arrow() # If document ID is defined, we use this as the sole grouping key to identify stale chunks, # else fallback to the compound `id_field_name`. 
@@ -950,20 +883,20 @@ def execute(self) -> None: get_columns_names_with_prop(self.table_schema, DOCUMENT_ID_HINT) or self.id_field_name ) + grouping_key = grouping_key if isinstance(grouping_key, list) else [grouping_key] + child_table_arrow: pa.Table = child_table.to_arrow().select( + [*grouping_key, "_dlt_load_id", "_dlt_id"] + ) grouped = child_table_arrow.group_by(grouping_key).aggregate( [("_dlt_load_id", "max")] ) joined = child_table_arrow.join(grouped, keys=grouping_key) - orphaned_mask = pc.not_equal( - joined["_dlt_load_id"], joined["_dlt_load_id_max"] - ) - orphaned_ids = ( - joined.filter(orphaned_mask).column("_dlt_id").to_pylist() - ) + orphaned_mask = pc.not_equal(joined["_dlt_load_id"], joined["_dlt_load_id_max"]) + orphaned_ids = joined.filter(orphaned_mask).column("_dlt_id").to_pylist() if len(orphaned_ids) > 1: - child_table.delete("_dlt_id IN" f" {tuple(orphaned_ids)}") + child_table.delete(f"_dlt_id IN {tuple(orphaned_ids)}") elif len(orphaned_ids) == 1: child_table.delete(f"_dlt_id = '{orphaned_ids.pop()}'") diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 1d944ff7a4..f202903598 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -6,7 +6,6 @@ from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider -from dlt.destinations.impl.lancedb.lancedb_adapter import DOCUMENT_ID_HINT PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = { @@ -42,10 +41,9 @@ def list_merge_identifiers(table_schema: TTableSchema) -> List[str]: Sequence[str]: A list of unique column identifiers. """ if table_schema.get("write_disposition") == "merge": - document_id = get_columns_names_with_prop(table_schema, DOCUMENT_ID_HINT) primary_keys = get_columns_names_with_prop(table_schema, "primary_key") merge_keys = get_columns_names_with_prop(table_schema, "merge_key") - if join_keys := list(set(primary_keys + merge_keys + document_id)): + if join_keys := list(set(primary_keys + merge_keys)): return join_keys return get_columns_names_with_prop(table_schema, "unique") diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 4679238010..285ecec577 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -119,12 +119,10 @@ def identity_resource( .reset_index(drop=True) ) - expected_child_data = expected_child_data.sort_values(by="bar").reset_index( + expected_child_data = expected_child_data.sort_values(by="bar").reset_index(drop=True) + expected_grandchild_data = expected_grandchild_data.sort_values(by="baz").reset_index( drop=True ) - expected_grandchild_data = expected_grandchild_data.sort_values( - by="baz" - ).reset_index(drop=True) assert_frame_equal(actual_child_df[["bar"]], expected_child_data) assert_frame_equal(actual_grandchild_df[["baz"]], expected_grandchild_data) @@ -182,15 +180,11 @@ def identity_resource( .reset_index(drop=True) ) - root_table_name = ( - client.make_qualified_table_name("root") # type: ignore[attr-defined] - ) + root_table_name = client.make_qualified_table_name("root") # type: ignore[attr-defined] tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] actual_root_df: DataFrame = ( - tbl.to_pandas() - .sort_values(by=["doc_id", "chunk_hash"]) - .reset_index(drop=True) + 
tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash"]).reset_index(drop=True) )[["doc_id", "chunk_hash"]] assert_frame_equal(actual_root_df, expected_root_table_df) From e33b7cffe95e444ced4300eeec84add4a2d92bad Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 1 Aug 2024 23:51:15 +0200 Subject: [PATCH 25/68] Add ollama to supported embedding providers and test orphaned record removal with embeddings Signed-off-by: Marcel Coetzee --- .../impl/lancedb/configuration.py | 1 + .../lancedb/test_remove_orphaned_records.py | 95 ++++++++++++++++++- 2 files changed, 91 insertions(+), 5 deletions(-) diff --git a/dlt/destinations/impl/lancedb/configuration.py b/dlt/destinations/impl/lancedb/configuration.py index ba3a8b49d9..5aa4ba714f 100644 --- a/dlt/destinations/impl/lancedb/configuration.py +++ b/dlt/destinations/impl/lancedb/configuration.py @@ -59,6 +59,7 @@ class LanceDBClientOptions(BaseConfiguration): "sentence-transformers", "huggingface", "colbert", + "ollama", ] diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 285ecec577..2c25a315d4 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -1,14 +1,21 @@ -from typing import Iterator, List, Generator +from typing import Iterator, List, Generator, Any +import numpy as np import pandas as pd import pytest +from lancedb.table import Table # type: ignore from pandas import DataFrame from pandas.testing import assert_frame_equal +from pyarrow import Table import dlt from dlt.common.typing import DictStrAny from dlt.common.utils import uniq_id -from dlt.destinations.impl.lancedb.lancedb_adapter import DOCUMENT_ID_HINT +from dlt.destinations.impl.lancedb.lancedb_adapter import ( + DOCUMENT_ID_HINT, + lancedb_adapter, +) +from tests.load.lancedb.utils import chunk_document from tests.load.utils import ( drop_active_pipeline_data, ) @@ -119,10 +126,12 @@ def identity_resource( .reset_index(drop=True) ) - expected_child_data = expected_child_data.sort_values(by="bar").reset_index(drop=True) - expected_grandchild_data = expected_grandchild_data.sort_values(by="baz").reset_index( + expected_child_data = expected_child_data.sort_values(by="bar").reset_index( drop=True ) + expected_grandchild_data = expected_grandchild_data.sort_values( + by="baz" + ).reset_index(drop=True) assert_frame_equal(actual_child_df[["bar"]], expected_child_data) assert_frame_equal(actual_grandchild_df[["baz"]], expected_grandchild_data) @@ -184,7 +193,83 @@ def identity_resource( tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] actual_root_df: DataFrame = ( - tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash"]).reset_index(drop=True) + tbl.to_pandas() + .sort_values(by=["doc_id", "chunk_hash"]) + .reset_index(drop=True) )[["doc_id", "chunk_hash"]] assert_frame_equal(actual_root_df, expected_root_table_df) + + +def test_lancedb_root_table_remove_orphaned_records_with_real_embeddings() -> None: + @dlt.resource( # type: ignore + write_disposition="merge", + table_name="document", + columns={"doc_id": {DOCUMENT_ID_HINT: True}}, + ) + def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: + for doc in docs: + doc_id = doc["doc_id"] + for chunk in chunk_document(doc["text"]): + yield {"doc_id": doc_id, "doc_text": doc["text"], "chunk": chunk} + + @dlt.source() + def documents_source( + docs: List[DictStrAny], + ) -> Any: + return documents(docs) + + lancedb_adapter( + documents, + 
embed=["chunk"], + ) + + pipeline = dlt.pipeline( + pipeline_name="test_lancedb_remove_orphaned_records_with_embeddings", + destination="lancedb", + dataset_name=f"test_lancedb_remove_orphaned_records_{uniq_id()}", + dev_mode=True, + ) + + initial_docs = [ + { + "text": ( + "This is the first document. It contains some text that will be chunked and" + " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)" + ), + "doc_id": 1, + }, + { + "text": "Here's another document. It's a bit different from the first one.", + "doc_id": 2, + }, + ] + + info = pipeline.run(documents_source(initial_docs)) + assert_load_info(info) + + updated_docs = [ + { + "text": "This is the first document, but it has been updated with new content.", + "doc_id": 1, + }, + { + "text": "This is a completely new document that wasn't in the initial set.", + "doc_id": 3, + }, + ] + + info = pipeline.run(documents_source(updated_docs)) + assert_load_info(info) + + with pipeline.destination_client() as client: + embeddings_table_name = client.make_qualified_table_name("document") # type: ignore[attr-defined] + tbl: Table = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] + df = tbl.to_pandas() + + # Check (non-empty) embeddings as present, and that orphaned embeddings have been discarded. + assert len(df) == 21 + assert "vector__" in df.columns + for _, vector in enumerate(df["vector__"]): + assert isinstance(vector, np.ndarray) + assert vector.size > 0 From f2913e9324af4a149a6579d68fa9676ae3ef7ca6 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Fri, 2 Aug 2024 19:36:24 +0200 Subject: [PATCH 26/68] Add merge_key to document resource for efficient updates in LanceDB Signed-off-by: Marcel Coetzee --- tests/load/lancedb/test_remove_orphaned_records.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 2c25a315d4..f4352ed71e 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -205,6 +205,7 @@ def test_lancedb_root_table_remove_orphaned_records_with_real_embeddings() -> No @dlt.resource( # type: ignore write_disposition="merge", table_name="document", + merge_key=["chunk"], columns={"doc_id": {DOCUMENT_ID_HINT: True}}, ) def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: From ffe6584a412e047264f69d4d58f6c25ec9bf1908 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Fri, 2 Aug 2024 19:51:53 +0200 Subject: [PATCH 27/68] Formatting Signed-off-by: Marcel Coetzee --- tests/load/lancedb/test_remove_orphaned_records.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index f4352ed71e..3118b06cc7 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -126,12 +126,10 @@ def identity_resource( .reset_index(drop=True) ) - expected_child_data = expected_child_data.sort_values(by="bar").reset_index( + expected_child_data = expected_child_data.sort_values(by="bar").reset_index(drop=True) + expected_grandchild_data = expected_grandchild_data.sort_values(by="baz").reset_index( drop=True ) - expected_grandchild_data = expected_grandchild_data.sort_values( - by="baz" - ).reset_index(drop=True) assert_frame_equal(actual_child_df[["bar"]], expected_child_data) 
assert_frame_equal(actual_grandchild_df[["baz"]], expected_grandchild_data) @@ -193,9 +191,7 @@ def identity_resource( tbl = client.db_client.open_table(root_table_name) # type: ignore[attr-defined] actual_root_df: DataFrame = ( - tbl.to_pandas() - .sort_values(by=["doc_id", "chunk_hash"]) - .reset_index(drop=True) + tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash"]).reset_index(drop=True) )[["doc_id", "chunk_hash"]] assert_frame_equal(actual_root_df, expected_root_table_df) From 036801829879d34d54052755560debb2f6c3e1f8 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Fri, 2 Aug 2024 20:03:09 +0200 Subject: [PATCH 28/68] Set default file size to 128MB Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/factory.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dlt/destinations/impl/lancedb/factory.py b/dlt/destinations/impl/lancedb/factory.py index f2e17168b9..714c8dcdfd 100644 --- a/dlt/destinations/impl/lancedb/factory.py +++ b/dlt/destinations/impl/lancedb/factory.py @@ -30,6 +30,8 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.decimal_precision = (38, 18) caps.timestamp_precision = 6 + caps.recommended_file_size = 128_000_000 + return caps @property From 02704d50d60dece8e5cc1abf6aa5b63edd561a7a Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sat, 3 Aug 2024 23:37:34 +0200 Subject: [PATCH 29/68] Only use parquet loader file formats Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/factory.py | 4 ++-- dlt/destinations/impl/lancedb/lancedb_client.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dlt/destinations/impl/lancedb/factory.py b/dlt/destinations/impl/lancedb/factory.py index 714c8dcdfd..9acb82344b 100644 --- a/dlt/destinations/impl/lancedb/factory.py +++ b/dlt/destinations/impl/lancedb/factory.py @@ -16,8 +16,8 @@ class lancedb(Destination[LanceDBClientConfiguration, "LanceDBClient"]): def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "jsonl" - caps.supported_loader_file_formats = ["jsonl"] + caps.preferred_loader_file_format = "parquet" + caps.supported_loader_file_formats = ["parquet"] caps.max_identifier_length = 200 caps.max_column_identifier_length = 1024 diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index cc35aadf18..76c3feb032 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -773,8 +773,8 @@ def __init__( TWriteDisposition, self.table_schema.get("write_disposition", "append") ) - with FileStorage.open_zipsafe_ro(local_path) as f: - records: List[DictStrAny] = [json.loads(line) for line in f] + with FileStorage.open_zipsafe_ro(local_path, mode="rb") as f: + arrow_table: pa.Table = pq.read_table(f, memory_map=True) if self.table_schema not in self.schema.dlt_tables(): for record in records: From eae056a5ea0f7c0ef738a0a835fb1ffab260d412 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 4 Aug 2024 17:58:01 +0200 Subject: [PATCH 30/68] Import pyarrow.parquet Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 523 ++++-------------- 1 file changed, 106 insertions(+), 417 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 76c3feb032..04f0e2cac0 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -1,23 
+1,11 @@ import uuid from types import TracebackType -from typing import ( - List, - Any, - cast, - Union, - Tuple, - Iterable, - Type, - Optional, - Dict, - Sequence, - TYPE_CHECKING, - Set, -) +from typing import (List, Any, cast, Union, Tuple, Iterable, Type, Optional, Dict, Sequence, TYPE_CHECKING, Set, ) import lancedb # type: ignore import pyarrow as pa import pyarrow.compute as pc +import pyarrow.parquet as pq from lancedb import DBConnection from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore @@ -27,54 +15,23 @@ from dlt.common import json, pendulum, logger from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.exceptions import ( - DestinationUndefinedEntity, - DestinationTransientException, - DestinationTerminalException, -) -from dlt.common.destination.reference import ( - JobClientBase, - WithStateSync, - LoadJob, - StorageSchemaInfo, - StateInfo, - TLoadJobState, - NewLoadJob, - FollowupJob, -) +from dlt.common.destination.exceptions import (DestinationUndefinedEntity, DestinationTransientException, + DestinationTerminalException, ) +from dlt.common.destination.reference import (JobClientBase, WithStateSync, LoadJob, StorageSchemaInfo, StateInfo, + TLoadJobState, NewLoadJob, FollowupJob, ) from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables -from dlt.common.schema.typing import ( - TColumnType, - TTableFormat, - TTableSchemaColumns, - TWriteDisposition, -) +from dlt.common.schema.typing import (TColumnType, TTableFormat, TTableSchemaColumns, TWriteDisposition, ) from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName from dlt.common.typing import DictStrAny -from dlt.destinations.impl.lancedb.configuration import ( - LanceDBClientConfiguration, -) -from dlt.destinations.impl.lancedb.exceptions import ( - lancedb_error, -) -from dlt.destinations.impl.lancedb.lancedb_adapter import ( - VECTORIZE_HINT, - DOCUMENT_ID_HINT, -) -from dlt.destinations.impl.lancedb.schema import ( - make_arrow_field_schema, - make_arrow_table_schema, - TArrowSchema, - NULL_SCHEMA, - TArrowField, -) -from dlt.destinations.impl.lancedb.utils import ( - list_merge_identifiers, - generate_uuid, - set_non_standard_providers_environment_variables, -) +from dlt.destinations.impl.lancedb.configuration import (LanceDBClientConfiguration, ) +from dlt.destinations.impl.lancedb.exceptions import (lancedb_error, ) +from dlt.destinations.impl.lancedb.lancedb_adapter import (VECTORIZE_HINT, DOCUMENT_ID_HINT, ) +from dlt.destinations.impl.lancedb.schema import (make_arrow_field_schema, make_arrow_table_schema, TArrowSchema, + NULL_SCHEMA, TArrowField, ) +from dlt.destinations.impl.lancedb.utils import (list_merge_identifiers, generate_uuid, + set_non_standard_providers_environment_variables, ) from dlt.destinations.job_impl import EmptyLoadJob, NewLoadJobImpl from dlt.destinations.type_mapping import TypeMapper @@ -88,80 +45,38 @@ class LanceDBTypeMapper(TypeMapper): - sct_to_unbound_dbt = { - "text": pa.string(), - "double": pa.float64(), - "bool": pa.bool_(), - "bigint": pa.int64(), - "binary": pa.binary(), - "date": pa.date32(), - "complex": pa.string(), - } + sct_to_unbound_dbt = {"text": pa.string(), "double": pa.float64(), "bool": pa.bool_(), "bigint": pa.int64(), "binary": pa.binary(), "date": pa.date32(), "complex": pa.string(), 
} sct_to_dbt = {} - dbt_to_sct = { - pa.string(): "text", - pa.float64(): "double", - pa.bool_(): "bool", - pa.int64(): "bigint", - pa.binary(): "binary", - pa.date32(): "date", - } - - def to_db_decimal_type( - self, precision: Optional[int], scale: Optional[int] - ) -> pa.Decimal128Type: + dbt_to_sct = {pa.string(): "text", pa.float64(): "double", pa.bool_(): "bool", pa.int64(): "bigint", pa.binary(): "binary", pa.date32(): "date", } + + def to_db_decimal_type(self, precision: Optional[int], scale: Optional[int]) -> pa.Decimal128Type: precision, scale = self.decimal_precision(precision, scale) return pa.decimal128(precision, scale) - def to_db_datetime_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> pa.TimestampType: + def to_db_datetime_type(self, precision: Optional[int], table_format: TTableFormat = None) -> pa.TimestampType: unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] return pa.timestamp(unit, "UTC") - def to_db_time_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> pa.Time64Type: + def to_db_time_type(self, precision: Optional[int], table_format: TTableFormat = None) -> pa.Time64Type: unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] return pa.time64(unit) - def from_db_type( - self, - db_type: pa.DataType, - precision: Optional[int] = None, - scale: Optional[int] = None, - ) -> TColumnType: + def from_db_type(self, db_type: pa.DataType, precision: Optional[int] = None, scale: Optional[int] = None, ) -> TColumnType: if isinstance(db_type, pa.TimestampType): - return dict( - data_type="timestamp", - precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], - scale=scale, - ) + return dict(data_type="timestamp", precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], scale=scale, ) if isinstance(db_type, pa.Time64Type): - return dict( - data_type="time", - precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], - scale=scale, - ) + return dict(data_type="time", precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], scale=scale, ) if isinstance(db_type, pa.Decimal128Type): precision, scale = db_type.precision, db_type.scale - if (precision, scale) == self.capabilities.wei_precision: + if (precision, scale)==self.capabilities.wei_precision: return cast(TColumnType, dict(data_type="wei")) return dict(data_type="decimal", precision=precision, scale=scale) return super().from_db_type(cast(str, db_type), precision, scale) -def upload_batch( - records: List[DictStrAny], - /, - *, - db_client: DBConnection, - table_name: str, - write_disposition: Optional[TWriteDisposition] = "append", - id_field_name: Optional[str] = None, -) -> None: +def upload_batch(records: List[DictStrAny], /, *, db_client: DBConnection, table_name: str, write_disposition: Optional[TWriteDisposition] = "append", id_field_name: Optional[str] = None, ) -> None: """Inserts records into a LanceDB table with automatic embedding computation. Args: @@ -180,30 +95,22 @@ def upload_batch( tbl = db_client.open_table(table_name) tbl.checkout_latest() except FileNotFoundError as e: - raise DestinationTransientException( - "Couldn't open lancedb database. Batch WILL BE RETRIED" - ) from e + raise DestinationTransientException("Couldn't open lancedb database. 
Batch WILL BE RETRIED") from e try: if write_disposition in ("append", "skip"): tbl.add(records) - elif write_disposition == "replace": + elif write_disposition=="replace": tbl.add(records, mode="overwrite") - elif write_disposition == "merge": + elif write_disposition=="merge": if not id_field_name: raise ValueError("To perform a merge update, 'id_field_name' must be specified.") - tbl.merge_insert( - id_field_name - ).when_matched_update_all().when_not_matched_insert_all().execute(records) + tbl.merge_insert(id_field_name).when_matched_update_all().when_not_matched_insert_all().execute(records) else: - raise DestinationTerminalException( - f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" - " failed AND WILL **NOT** BE RETRIED." - ) + raise DestinationTerminalException(f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" + " failed AND WILL **NOT** BE RETRIED.") except ArrowInvalid as e: - raise DestinationTerminalException( - "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." - ) from e + raise DestinationTerminalException("Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED.") from e class LanceDBClient(JobClientBase, WithStateSync): @@ -211,19 +118,10 @@ class LanceDBClient(JobClientBase, WithStateSync): model_func: TextEmbeddingFunction - def __init__( - self, - schema: Schema, - config: LanceDBClientConfiguration, - capabilities: DestinationCapabilitiesContext, - ) -> None: + def __init__(self, schema: Schema, config: LanceDBClientConfiguration, capabilities: DestinationCapabilitiesContext, ) -> None: super().__init__(schema, config, capabilities) self.config: LanceDBClientConfiguration = config - self.db_client: DBConnection = lancedb.connect( - uri=self.config.credentials.uri, - api_key=self.config.credentials.api_key, - read_consistency_interval=timedelta(0), - ) + self.db_client: DBConnection = lancedb.connect(uri=self.config.credentials.uri, api_key=self.config.credentials.api_key, read_consistency_interval=timedelta(0), ) self.registry = EmbeddingFunctionRegistry.get_instance() self.type_mapper = LanceDBTypeMapper(self.capabilities) self.sentinel_table_name = config.sentinel_table_name @@ -233,24 +131,14 @@ def __init__( # LanceDB doesn't provide a standardized way to set API keys across providers. # Some use ENV variables and others allow passing api key as an argument. # To account for this, we set provider environment variable as well. - set_non_standard_providers_environment_variables( - embedding_model_provider, - self.config.credentials.embedding_model_provider_api_key, - ) + set_non_standard_providers_environment_variables(embedding_model_provider, self.config.credentials.embedding_model_provider_api_key, ) # Use the monkey-patched implementation if openai was chosen. 
- if embedding_model_provider == "openai": + if embedding_model_provider=="openai": from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings - self.model_func = PatchedOpenAIEmbeddings( - max_retries=self.config.options.max_retries, - api_key=self.config.credentials.api_key, - ) + self.model_func = PatchedOpenAIEmbeddings(max_retries=self.config.options.max_retries, api_key=self.config.credentials.api_key, ) else: - self.model_func = self.registry.get(embedding_model_provider).create( - name=self.config.embedding_model, - max_retries=self.config.options.max_retries, - api_key=self.config.credentials.api_key, - ) + self.model_func = self.registry.get(embedding_model_provider).create(name=self.config.embedding_model, max_retries=self.config.options.max_retries, api_key=self.config.credentials.api_key, ) self.vector_field_name = self.config.vector_field_name self.id_field_name = self.config.id_field_name @@ -264,20 +152,13 @@ def sentinel_table(self) -> str: return self.make_qualified_table_name(self.sentinel_table_name) def make_qualified_table_name(self, table_name: str) -> str: - return ( - f"{self.dataset_name}{self.config.dataset_separator}{table_name}" - if self.dataset_name - else table_name - ) + return (f"{self.dataset_name}{self.config.dataset_separator}{table_name}" if self.dataset_name else table_name) def get_table_schema(self, table_name: str) -> TArrowSchema: schema_table: Table = self.db_client.open_table(table_name) schema_table.checkout_latest() schema = schema_table.schema - return cast( - TArrowSchema, - schema, - ) + return cast(TArrowSchema, schema, ) @lancedb_error def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table: @@ -300,11 +181,7 @@ def delete_table(self, table_name: str) -> None: """ self.db_client.drop_table(table_name) - def query_table( - self, - table_name: str, - query: Union[List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None] = None, - ) -> LanceQueryBuilder: + def query_table(self, table_name: str, query: Union[List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None] = None, ) -> LanceQueryBuilder: """Query a LanceDB table. 
Args: @@ -323,15 +200,11 @@ def _get_table_names(self) -> List[str]: """Return all tables in the dataset, excluding the sentinel table.""" if self.dataset_name: prefix = f"{self.dataset_name}{self.config.dataset_separator}" - table_names = [ - table_name - for table_name in self.db_client.table_names() - if table_name.startswith(prefix) - ] + table_names = [table_name for table_name in self.db_client.table_names() if table_name.startswith(prefix)] else: table_names = self.db_client.table_names() - return [table_name for table_name in table_names if table_name != self.sentinel_table] + return [table_name for table_name in table_names if table_name!=self.sentinel_table] @lancedb_error def drop_storage(self) -> None: @@ -357,10 +230,7 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: continue schema = self.get_table_schema(fq_table_name) self.db_client.drop_table(fq_table_name) - self.create_table( - table_name=fq_table_name, - schema=schema, - ) + self.create_table(table_name=fq_table_name, schema=schema, ) @lancedb_error def is_storage_initialized(self) -> bool: @@ -375,11 +245,7 @@ def _delete_sentinel_table(self) -> None: self.db_client.drop_table(self.sentinel_table) @lancedb_error - def update_stored_schema( - self, - only_tables: Iterable[str] = None, - expected_update: TSchemaTables = None, - ) -> Optional[TSchemaTables]: + def update_stored_schema(self, only_tables: Iterable[str] = None, expected_update: TSchemaTables = None, ) -> Optional[TSchemaTables]: super().update_stored_schema(only_tables, expected_update) applied_update: TSchemaTables = {} @@ -389,17 +255,13 @@ def update_stored_schema( schema_info = None if schema_info is None: - logger.info( - f"Schema with hash {self.schema.stored_version_hash} " - "not found in the storage. upgrading" - ) + logger.info(f"Schema with hash {self.schema.stored_version_hash} " + "not found in the storage. upgrading") self._execute_schema_update(only_tables) else: - logger.info( - f"Schema with hash {self.schema.stored_version_hash} " - f"inserted at {schema_info.inserted_at} found " - "in storage, no upgrade required" - ) + logger.info(f"Schema with hash {self.schema.stored_version_hash} " + f"inserted at {schema_info.inserted_at} found " + "in storage, no upgrade required") return applied_update def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: @@ -417,16 +279,11 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] field: TArrowField for field in arrow_schema: name = self.schema.naming.normalize_identifier(field.name) - table_schema[name] = { - "name": name, - **self.type_mapper.from_db_type(field.type), - } + table_schema[name] = {"name": name, **self.type_mapper.from_db_type(field.type), } return True, table_schema @lancedb_error - def add_table_fields( - self, table_name: str, field_schemas: List[TArrowField] - ) -> Optional[Table]: + def add_table_fields(self, table_name: str, field_schemas: List[TArrowField]) -> Optional[Table]: """Add multiple fields to the LanceDB table at once. 
Args: @@ -459,25 +316,16 @@ def add_table_fields( def _execute_schema_update(self, only_tables: Iterable[str]) -> None: for table_name in only_tables or self.schema.tables: exists, existing_columns = self.get_storage_table(table_name) - new_columns = self.schema.get_new_table_columns( - table_name, - existing_columns, - self.capabilities.generates_case_sensitive_identifiers(), - ) + new_columns = self.schema.get_new_table_columns(table_name, existing_columns, self.capabilities.generates_case_sensitive_identifiers(), ) logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") if len(new_columns) > 0: if exists: - field_schemas: List[TArrowField] = [ - make_arrow_field_schema(column["name"], column, self.type_mapper) - for column in new_columns - ] + field_schemas: List[TArrowField] = [make_arrow_field_schema(column["name"], column, self.type_mapper) for column in new_columns] fq_table_name = self.make_qualified_table_name(table_name) self.add_table_fields(fq_table_name, field_schemas) else: if table_name not in self.schema.dlt_table_names(): - embedding_fields = get_columns_names_with_prop( - self.schema.get_table(table_name=table_name), VECTORIZE_HINT - ) + embedding_fields = get_columns_names_with_prop(self.schema.get_table(table_name=table_name), VECTORIZE_HINT) vector_field_name = self.vector_field_name id_field_name = self.id_field_name embedding_model_func = self.model_func @@ -489,16 +337,7 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: embedding_model_func = None embedding_model_dimensions = None - table_schema: TArrowSchema = make_arrow_table_schema( - table_name, - schema=self.schema, - type_mapper=self.type_mapper, - embedding_fields=embedding_fields, - embedding_model_func=embedding_model_func, - embedding_model_dimensions=embedding_model_dimensions, - vector_field_name=vector_field_name, - id_field_name=id_field_name, - ) + table_schema: TArrowSchema = make_arrow_table_schema(table_name, schema=self.schema, type_mapper=self.type_mapper, embedding_fields=embedding_fields, embedding_model_func=embedding_model_func, embedding_model_dimensions=embedding_model_dimensions, vector_field_name=vector_field_name, id_field_name=id_field_name, ) fq_table_name = self.make_qualified_table_name(table_name) self.create_table(fq_table_name, table_schema) @@ -506,33 +345,13 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: @lancedb_error def update_schema_in_storage(self) -> None: - records = [ - { - self.schema.naming.normalize_identifier("version"): self.schema.version, - self.schema.naming.normalize_identifier( - "engine_version" - ): self.schema.ENGINE_VERSION, - self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), - self.schema.naming.normalize_identifier("schema_name"): self.schema.name, - self.schema.naming.normalize_identifier( - "version_hash" - ): self.schema.stored_version_hash, - self.schema.naming.normalize_identifier("schema"): json.dumps( - self.schema.to_dict() - ), - } - ] + records = [{self.schema.naming.normalize_identifier("version"): self.schema.version, self.schema.naming.normalize_identifier("engine_version"): self.schema.ENGINE_VERSION, + self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier("version_hash"): self.schema.stored_version_hash, self.schema.naming.normalize_identifier("schema"): json.dumps(self.schema.to_dict()), }] 
fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) - write_disposition = self.schema.get_table(self.schema.version_table_name).get( - "write_disposition" - ) + write_disposition = self.schema.get_table(self.schema.version_table_name).get("write_disposition") - upload_batch( - records, - db_client=self.db_client, - table_name=fq_version_table_name, - write_disposition=write_disposition, - ) + upload_batch(records, db_client=self.db_client, table_name=fq_version_table_name, write_disposition=write_disposition, ) @lancedb_error def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: @@ -559,31 +378,18 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: # Read the tables into memory as Arrow tables, with pushdown predicates, so we pull as less # data into memory as possible. - state_table = ( - state_table_.search() - .where(f"`{p_pipeline_name}` = '{pipeline_name}'", prefilter=True) - .to_arrow() - ) + state_table = (state_table_.search().where(f"`{p_pipeline_name}` = '{pipeline_name}'", prefilter=True).to_arrow()) loads_table = loads_table_.search().where(f"`{p_status}` = 0", prefilter=True).to_arrow() # Join arrow tables in-memory. - joined_table: pa.Table = state_table.join( - loads_table, keys=p_dlt_load_id, right_keys=p_load_id, join_type="inner" - ).sort_by([(p_dlt_load_id, "descending")]) + joined_table: pa.Table = state_table.join(loads_table, keys=p_dlt_load_id, right_keys=p_load_id, join_type="inner").sort_by([(p_dlt_load_id, "descending")]) - if joined_table.num_rows == 0: + if joined_table.num_rows==0: return None state = joined_table.take([0]).to_pylist()[0] - return StateInfo( - version=state[p_version], - engine_version=state[p_engine_version], - pipeline_name=state[p_pipeline_name], - state=state[p_state], - created_at=pendulum.instance(state[p_created_at]), - version_hash=state[p_version_hash], - _dlt_load_id=state[p_dlt_load_id], - ) + return StateInfo(version=state[p_version], engine_version=state[p_engine_version], pipeline_name=state[p_pipeline_name], state=state[p_state], created_at=pendulum.instance( + state[p_created_at]), version_hash=state[p_version_hash], _dlt_load_id=state[p_dlt_load_id], ) @lancedb_error def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: @@ -599,21 +405,11 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI p_schema = self.schema.naming.normalize_identifier("schema") try: - schemas = ( - version_table.search().where( - f'`{p_version_hash}` = "{schema_hash}"', prefilter=True - ) - ).to_list() + schemas = (version_table.search().where(f'`{p_version_hash}` = "{schema_hash}"', prefilter=True)).to_list() most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] - return StorageSchemaInfo( - version_hash=most_recent_schema[p_version_hash], - schema_name=most_recent_schema[p_schema_name], - version=most_recent_schema[p_version], - engine_version=most_recent_schema[p_engine_version], - inserted_at=most_recent_schema[p_inserted_at], - schema=most_recent_schema[p_schema], - ) + return StorageSchemaInfo(version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], version=most_recent_schema[p_version], engine_version= + most_recent_schema[p_engine_version], inserted_at=most_recent_schema[p_inserted_at], schema=most_recent_schema[p_schema], ) except IndexError: return None @@ -632,30 +428,15 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: 
p_schema = self.schema.naming.normalize_identifier("schema") try: - schemas = ( - version_table.search().where( - f'`{p_schema_name}` = "{self.schema.name}"', prefilter=True - ) - ).to_list() + schemas = (version_table.search().where(f'`{p_schema_name}` = "{self.schema.name}"', prefilter=True)).to_list() most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] - return StorageSchemaInfo( - version_hash=most_recent_schema[p_version_hash], - schema_name=most_recent_schema[p_schema_name], - version=most_recent_schema[p_version], - engine_version=most_recent_schema[p_engine_version], - inserted_at=most_recent_schema[p_inserted_at], - schema=most_recent_schema[p_schema], - ) + return StorageSchemaInfo(version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], version=most_recent_schema[p_version], engine_version= + most_recent_schema[p_engine_version], inserted_at=most_recent_schema[p_inserted_at], schema=most_recent_schema[p_schema], ) except IndexError: return None - def __exit__( - self, - exc_type: Type[BaseException], - exc_val: BaseException, - exc_tb: TracebackType, - ) -> None: + def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType, ) -> None: pass def __enter__(self) -> "LanceDBClient": @@ -663,27 +444,13 @@ def __enter__(self) -> "LanceDBClient": @lancedb_error def complete_load(self, load_id: str) -> None: - records = [ - { - self.schema.naming.normalize_identifier("load_id"): load_id, - self.schema.naming.normalize_identifier("schema_name"): self.schema.name, - self.schema.naming.normalize_identifier("status"): 0, - self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), - self.schema.naming.normalize_identifier( - "schema_version_hash" - ): None, # Payload schema must match the target schema. - } - ] + records = [{self.schema.naming.normalize_identifier("load_id"): load_id, self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier("status"): 0, self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), + self.schema.naming.normalize_identifier("schema_version_hash"): None, # Payload schema must match the target schema. 
+ }] fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) - write_disposition = self.schema.get_table(self.schema.loads_table_name).get( - "write_disposition" - ) - upload_batch( - records, - db_client=self.db_client, - table_name=fq_loads_table_name, - write_disposition=write_disposition, - ) + write_disposition = self.schema.get_table(self.schema.loads_table_name).get("write_disposition") + upload_batch(records, db_client=self.db_client, table_name=fq_loads_table_name, write_disposition=write_disposition, ) def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") @@ -691,48 +458,22 @@ def restore_file_load(self, file_path: str) -> LoadJob: def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: parent_table = table.get("parent") - return LoadLanceDBJob( - self.schema, - table, - file_path, - type_mapper=self.type_mapper, - db_client=self.db_client, - client_config=self.config, - model_func=self.model_func, - fq_table_name=self.make_qualified_table_name(table["name"]), - fq_parent_table_name=( - self.make_qualified_table_name(parent_table) if parent_table else None - ), - ) - - def create_table_chain_completed_followup_jobs( - self, - table_chain: Sequence[TTableSchema], - completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[NewLoadJob]: + return LoadLanceDBJob(self.schema, table, file_path, type_mapper=self.type_mapper, db_client=self.db_client, client_config=self.config, model_func=self.model_func, fq_table_name=self.make_qualified_table_name( + table["name"]), fq_parent_table_name=(self.make_qualified_table_name(parent_table) if parent_table else None), ) + + def create_table_chain_completed_followup_jobs(self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[NewLoadJob]: assert completed_table_chain_jobs is not None - jobs = super().create_table_chain_completed_followup_jobs( - table_chain, completed_table_chain_jobs - ) + jobs = super().create_table_chain_completed_followup_jobs(table_chain, completed_table_chain_jobs) for table in table_chain: if table in self.schema.dlt_tables(): continue # Only tables with merge disposition are dispatched for orphan removal jobs. 
- if table.get("write_disposition") == "merge": + if table.get("write_disposition")=="merge": parent_table = table.get("parent") - jobs.append( - LanceDBRemoveOrphansJob( - db_client=self.db_client, - table_schema=self.prepare_load_table(table["name"]), - fq_table_name=self.make_qualified_table_name(table["name"]), - fq_parent_table_name=( - self.make_qualified_table_name(parent_table) if parent_table else None - ), - client_config=self.config, - ) - ) + jobs.append(LanceDBRemoveOrphansJob(db_client=self.db_client, table_schema=self.prepare_load_table(table["name"]), fq_table_name=self.make_qualified_table_name( + table["name"]), fq_parent_table_name=(self.make_qualified_table_name(parent_table) if parent_table else None), client_config=self.config, )) return jobs @@ -743,18 +484,8 @@ def table_exists(self, table_name: str) -> bool: class LoadLanceDBJob(LoadJob, FollowupJob): arrow_schema: TArrowSchema - def __init__( - self, - schema: Schema, - table_schema: TTableSchema, - local_path: str, - type_mapper: LanceDBTypeMapper, - db_client: DBConnection, - client_config: LanceDBClientConfiguration, - model_func: TextEmbeddingFunction, - fq_table_name: str, - fq_parent_table_name: Optional[str], - ) -> None: + def __init__(self, schema: Schema, table_schema: TTableSchema, local_path: str, type_mapper: LanceDBTypeMapper, db_client: DBConnection, client_config: LanceDBClientConfiguration, model_func: TextEmbeddingFunction, fq_table_name: str, fq_parent_table_name: + Optional[str], ) -> None: file_name = FileStorage.get_file_name_from_file_path(local_path) super().__init__(file_name) self.schema: Schema = schema @@ -769,9 +500,7 @@ def __init__( self.embedding_model_func: TextEmbeddingFunction = model_func self.embedding_model_dimensions: int = client_config.embedding_model_dimensions self.id_field_name: str = client_config.id_field_name - self.write_disposition: TWriteDisposition = cast( - TWriteDisposition, self.table_schema.get("write_disposition", "append") - ) + self.write_disposition: TWriteDisposition = cast(TWriteDisposition, self.table_schema.get("write_disposition", "append")) with FileStorage.open_zipsafe_ro(local_path, mode="rb") as f: arrow_table: pa.Table = pq.read_table(f, memory_map=True) @@ -779,11 +508,7 @@ def __init__( if self.table_schema not in self.schema.dlt_tables(): for record in records: # Add reserved ID fields. - uuid_id = ( - generate_uuid(record, self.unique_identifiers, self.fq_table_name) - if self.unique_identifiers - else str(uuid.uuid4()) - ) + uuid_id = (generate_uuid(record, self.unique_identifiers, self.fq_table_name) if self.unique_identifiers else str(uuid.uuid4())) record.update({self.id_field_name: uuid_id}) # LanceDB expects all fields in the target arrow table to be present in the data payload. 
@@ -792,13 +517,7 @@ def __init__( for field in missing_fields: record[field] = None - upload_batch( - records, - db_client=db_client, - table_name=self.fq_table_name, - write_disposition=self.write_disposition, - id_field_name=self.id_field_name, - ) + upload_batch(records, db_client=db_client, table_name=self.fq_table_name, write_disposition=self.write_disposition, id_field_name=self.id_field_name, ) def state(self) -> TLoadJobState: return "completed" @@ -808,34 +527,17 @@ def exception(self) -> str: class LanceDBRemoveOrphansJob(NewLoadJobImpl): - def __init__( - self, - db_client: DBConnection, - table_schema: TTableSchema, - client_config: LanceDBClientConfiguration, - fq_table_name: str, - fq_parent_table_name: Optional[str], - ) -> None: + def __init__(self, db_client: DBConnection, table_schema: TTableSchema, client_config: LanceDBClientConfiguration, fq_table_name: str, fq_parent_table_name: Optional[str], ) -> None: self.db_client = db_client self.table_schema: TTableSchema = table_schema self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name - self.write_disposition: TWriteDisposition = cast( - TWriteDisposition, self.table_schema.get("write_disposition") - ) + self.write_disposition: TWriteDisposition = cast(TWriteDisposition, self.table_schema.get("write_disposition")) self.id_field_name: str = client_config.id_field_name - job_id = ParsedLoadJobFileName( - table_schema["name"], - ParsedLoadJobFileName.new_file_id(), - 0, - "parquet", - ).file_name() + job_id = ParsedLoadJobFileName(table_schema["name"], ParsedLoadJobFileName.new_file_id(), 0, "parquet", ).file_name() - super().__init__( - file_name=job_id, - status="running", - ) + super().__init__(file_name=job_id, status="running", ) self._save_text_file("") @@ -844,11 +546,9 @@ def __init__( def execute(self) -> None: orphaned_ids: Set[str] - if self.write_disposition != "merge": - raise DestinationTerminalException( - f"Unsupported write disposition {self.write_disposition} for LanceDB Destination" - " Orphan Removal Job - failed AND WILL **NOT** BE RETRIED." - ) + if self.write_disposition!="merge": + raise DestinationTerminalException(f"Unsupported write disposition {self.write_disposition} for LanceDB Destination" + " Orphan Removal Job - failed AND WILL **NOT** BE RETRIED.") # Orphans are removed irrespective of which merge strategy is picked. try: @@ -858,9 +558,7 @@ def execute(self) -> None: parent_table = self.db_client.open_table(self.fq_parent_table_name) parent_table.checkout_latest() except FileNotFoundError as e: - raise DestinationTransientException( - "Couldn't open lancedb database. Orphan removal WILL BE RETRIED" - ) from e + raise DestinationTransientException("Couldn't open lancedb database. Orphan removal WILL BE RETRIED") from e try: if self.fq_parent_table_name: @@ -871,7 +569,7 @@ def execute(self) -> None: if orphaned_ids := child_ids - parent_ids: if len(orphaned_ids) > 1: child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids) == 1: + elif len(orphaned_ids)==1: child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") else: @@ -879,31 +577,22 @@ def execute(self) -> None: # If document ID is defined, we use this as the sole grouping key to identify stale chunks, # else fallback to the compound `id_field_name`. 
- grouping_key = ( - get_columns_names_with_prop(self.table_schema, DOCUMENT_ID_HINT) - or self.id_field_name - ) + grouping_key = (get_columns_names_with_prop(self.table_schema, DOCUMENT_ID_HINT) or self.id_field_name) grouping_key = grouping_key if isinstance(grouping_key, list) else [grouping_key] - child_table_arrow: pa.Table = child_table.to_arrow().select( - [*grouping_key, "_dlt_load_id", "_dlt_id"] - ) + child_table_arrow: pa.Table = child_table.to_arrow().select([*grouping_key, "_dlt_load_id", "_dlt_id"]) - grouped = child_table_arrow.group_by(grouping_key).aggregate( - [("_dlt_load_id", "max")] - ) + grouped = child_table_arrow.group_by(grouping_key).aggregate([("_dlt_load_id", "max")]) joined = child_table_arrow.join(grouped, keys=grouping_key) orphaned_mask = pc.not_equal(joined["_dlt_load_id"], joined["_dlt_load_id_max"]) orphaned_ids = joined.filter(orphaned_mask).column("_dlt_id").to_pylist() if len(orphaned_ids) > 1: child_table.delete(f"_dlt_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids) == 1: + elif len(orphaned_ids)==1: child_table.delete(f"_dlt_id = '{orphaned_ids.pop()}'") except ArrowInvalid as e: - raise DestinationTerminalException( - "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." - ) from e + raise DestinationTerminalException("Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED.") from e def state(self) -> TLoadJobState: return "completed" From dc20a55986e6866993577d2825504e5fd9eb7b63 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 4 Aug 2024 23:12:28 +0200 Subject: [PATCH 31/68] Remove recommended file size from LanceDB destination capabilities Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/factory.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dlt/destinations/impl/lancedb/factory.py b/dlt/destinations/impl/lancedb/factory.py index 9acb82344b..beab6a4031 100644 --- a/dlt/destinations/impl/lancedb/factory.py +++ b/dlt/destinations/impl/lancedb/factory.py @@ -30,8 +30,6 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.decimal_precision = (38, 18) caps.timestamp_precision = 6 - caps.recommended_file_size = 128_000_000 - return caps @property From 6ed540b0b4fbc6281abc66b727658aeeb3701d06 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 4 Aug 2024 23:26:14 +0200 Subject: [PATCH 32/68] Update LanceDB client to use more efficient batch processing methods on loading for Load Jobs Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 537 ++++++++++++++---- dlt/destinations/impl/lancedb/utils.py | 36 +- 2 files changed, 447 insertions(+), 126 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 04f0e2cac0..1294155c00 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -1,6 +1,18 @@ -import uuid from types import TracebackType -from typing import (List, Any, cast, Union, Tuple, Iterable, Type, Optional, Dict, Sequence, TYPE_CHECKING, Set, ) +from typing import ( + List, + Any, + cast, + Union, + Tuple, + Iterable, + Type, + Optional, + Dict, + Sequence, + TYPE_CHECKING, + Set, +) import lancedb # type: ignore import pyarrow as pa @@ -15,23 +27,54 @@ from dlt.common import json, pendulum, logger from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.exceptions import (DestinationUndefinedEntity, DestinationTransientException, - 
DestinationTerminalException, ) -from dlt.common.destination.reference import (JobClientBase, WithStateSync, LoadJob, StorageSchemaInfo, StateInfo, - TLoadJobState, NewLoadJob, FollowupJob, ) +from dlt.common.destination.exceptions import ( + DestinationUndefinedEntity, + DestinationTransientException, + DestinationTerminalException, +) +from dlt.common.destination.reference import ( + JobClientBase, + WithStateSync, + LoadJob, + StorageSchemaInfo, + StateInfo, + TLoadJobState, + NewLoadJob, + FollowupJob, +) from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables -from dlt.common.schema.typing import (TColumnType, TTableFormat, TTableSchemaColumns, TWriteDisposition, ) +from dlt.common.schema.typing import ( + TColumnType, + TTableFormat, + TTableSchemaColumns, + TWriteDisposition, +) from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName from dlt.common.typing import DictStrAny -from dlt.destinations.impl.lancedb.configuration import (LanceDBClientConfiguration, ) -from dlt.destinations.impl.lancedb.exceptions import (lancedb_error, ) -from dlt.destinations.impl.lancedb.lancedb_adapter import (VECTORIZE_HINT, DOCUMENT_ID_HINT, ) -from dlt.destinations.impl.lancedb.schema import (make_arrow_field_schema, make_arrow_table_schema, TArrowSchema, - NULL_SCHEMA, TArrowField, ) -from dlt.destinations.impl.lancedb.utils import (list_merge_identifiers, generate_uuid, - set_non_standard_providers_environment_variables, ) +from dlt.destinations.impl.lancedb.configuration import ( + LanceDBClientConfiguration, +) +from dlt.destinations.impl.lancedb.exceptions import ( + lancedb_error, +) +from dlt.destinations.impl.lancedb.lancedb_adapter import ( + VECTORIZE_HINT, + DOCUMENT_ID_HINT, +) +from dlt.destinations.impl.lancedb.schema import ( + make_arrow_field_schema, + make_arrow_table_schema, + TArrowSchema, + NULL_SCHEMA, + TArrowField, +) +from dlt.destinations.impl.lancedb.utils import ( + list_merge_identifiers, + set_non_standard_providers_environment_variables, + generate_arrow_uuid_column, +) from dlt.destinations.job_impl import EmptyLoadJob, NewLoadJobImpl from dlt.destinations.type_mapping import TypeMapper @@ -45,38 +88,80 @@ class LanceDBTypeMapper(TypeMapper): - sct_to_unbound_dbt = {"text": pa.string(), "double": pa.float64(), "bool": pa.bool_(), "bigint": pa.int64(), "binary": pa.binary(), "date": pa.date32(), "complex": pa.string(), } + sct_to_unbound_dbt = { + "text": pa.string(), + "double": pa.float64(), + "bool": pa.bool_(), + "bigint": pa.int64(), + "binary": pa.binary(), + "date": pa.date32(), + "complex": pa.string(), + } sct_to_dbt = {} - dbt_to_sct = {pa.string(): "text", pa.float64(): "double", pa.bool_(): "bool", pa.int64(): "bigint", pa.binary(): "binary", pa.date32(): "date", } - - def to_db_decimal_type(self, precision: Optional[int], scale: Optional[int]) -> pa.Decimal128Type: + dbt_to_sct = { + pa.string(): "text", + pa.float64(): "double", + pa.bool_(): "bool", + pa.int64(): "bigint", + pa.binary(): "binary", + pa.date32(): "date", + } + + def to_db_decimal_type( + self, precision: Optional[int], scale: Optional[int] + ) -> pa.Decimal128Type: precision, scale = self.decimal_precision(precision, scale) return pa.decimal128(precision, scale) - def to_db_datetime_type(self, precision: Optional[int], table_format: TTableFormat = None) -> pa.TimestampType: + def to_db_datetime_type( + self, precision: Optional[int], table_format: 
TTableFormat = None + ) -> pa.TimestampType: unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] return pa.timestamp(unit, "UTC") - def to_db_time_type(self, precision: Optional[int], table_format: TTableFormat = None) -> pa.Time64Type: + def to_db_time_type( + self, precision: Optional[int], table_format: TTableFormat = None + ) -> pa.Time64Type: unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] return pa.time64(unit) - def from_db_type(self, db_type: pa.DataType, precision: Optional[int] = None, scale: Optional[int] = None, ) -> TColumnType: + def from_db_type( + self, + db_type: pa.DataType, + precision: Optional[int] = None, + scale: Optional[int] = None, + ) -> TColumnType: if isinstance(db_type, pa.TimestampType): - return dict(data_type="timestamp", precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], scale=scale, ) + return dict( + data_type="timestamp", + precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], + scale=scale, + ) if isinstance(db_type, pa.Time64Type): - return dict(data_type="time", precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], scale=scale, ) + return dict( + data_type="time", + precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], + scale=scale, + ) if isinstance(db_type, pa.Decimal128Type): precision, scale = db_type.precision, db_type.scale - if (precision, scale)==self.capabilities.wei_precision: + if (precision, scale) == self.capabilities.wei_precision: return cast(TColumnType, dict(data_type="wei")) return dict(data_type="decimal", precision=precision, scale=scale) return super().from_db_type(cast(str, db_type), precision, scale) -def upload_batch(records: List[DictStrAny], /, *, db_client: DBConnection, table_name: str, write_disposition: Optional[TWriteDisposition] = "append", id_field_name: Optional[str] = None, ) -> None: +def write_to_db( + records: Union[pa.Table, List[DictStrAny]], + /, + *, + db_client: DBConnection, + table_name: str, + write_disposition: Optional[TWriteDisposition] = "append", + id_field_name: Optional[str] = None, +) -> None: """Inserts records into a LanceDB table with automatic embedding computation. Args: @@ -95,22 +180,30 @@ def upload_batch(records: List[DictStrAny], /, *, db_client: DBConnection, table tbl = db_client.open_table(table_name) tbl.checkout_latest() except FileNotFoundError as e: - raise DestinationTransientException("Couldn't open lancedb database. Batch WILL BE RETRIED") from e + raise DestinationTransientException( + "Couldn't open lancedb database. Batch WILL BE RETRIED" + ) from e try: if write_disposition in ("append", "skip"): tbl.add(records) - elif write_disposition=="replace": + elif write_disposition == "replace": tbl.add(records, mode="overwrite") - elif write_disposition=="merge": + elif write_disposition == "merge": if not id_field_name: raise ValueError("To perform a merge update, 'id_field_name' must be specified.") - tbl.merge_insert(id_field_name).when_matched_update_all().when_not_matched_insert_all().execute(records) + tbl.merge_insert( + id_field_name + ).when_matched_update_all().when_not_matched_insert_all().execute(records) else: - raise DestinationTerminalException(f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" - " failed AND WILL **NOT** BE RETRIED.") + raise DestinationTerminalException( + f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" + " failed AND WILL **NOT** BE RETRIED." 
+ ) except ArrowInvalid as e: - raise DestinationTerminalException("Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED.") from e + raise DestinationTerminalException( + "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." + ) from e class LanceDBClient(JobClientBase, WithStateSync): @@ -118,10 +211,19 @@ class LanceDBClient(JobClientBase, WithStateSync): model_func: TextEmbeddingFunction - def __init__(self, schema: Schema, config: LanceDBClientConfiguration, capabilities: DestinationCapabilitiesContext, ) -> None: + def __init__( + self, + schema: Schema, + config: LanceDBClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: super().__init__(schema, config, capabilities) self.config: LanceDBClientConfiguration = config - self.db_client: DBConnection = lancedb.connect(uri=self.config.credentials.uri, api_key=self.config.credentials.api_key, read_consistency_interval=timedelta(0), ) + self.db_client: DBConnection = lancedb.connect( + uri=self.config.credentials.uri, + api_key=self.config.credentials.api_key, + read_consistency_interval=timedelta(0), + ) self.registry = EmbeddingFunctionRegistry.get_instance() self.type_mapper = LanceDBTypeMapper(self.capabilities) self.sentinel_table_name = config.sentinel_table_name @@ -131,14 +233,24 @@ def __init__(self, schema: Schema, config: LanceDBClientConfiguration, capabilit # LanceDB doesn't provide a standardized way to set API keys across providers. # Some use ENV variables and others allow passing api key as an argument. # To account for this, we set provider environment variable as well. - set_non_standard_providers_environment_variables(embedding_model_provider, self.config.credentials.embedding_model_provider_api_key, ) + set_non_standard_providers_environment_variables( + embedding_model_provider, + self.config.credentials.embedding_model_provider_api_key, + ) # Use the monkey-patched implementation if openai was chosen. 
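As a rough illustration of the merge_insert upsert idiom that the write_to_db helper above relies on for the merge disposition, here is a standalone sketch (not part of the patch) assuming a throwaway local LanceDB database; the table name and values are invented:

    import lancedb

    db = lancedb.connect("/tmp/lancedb_upsert_demo")  # throwaway local database
    tbl = db.create_table("docs_demo", data=[{"id": "1", "text": "old"}], mode="overwrite")
    # Rows matching on "id" are updated in place, unmatched rows are inserted.
    tbl.merge_insert("id").when_matched_update_all().when_not_matched_insert_all().execute(
        [{"id": "1", "text": "new"}, {"id": "2", "text": "added"}]
    )
    assert tbl.count_rows() == 2  # "1" was upserted rather than duplicated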
- if embedding_model_provider=="openai": + if embedding_model_provider == "openai": from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings - self.model_func = PatchedOpenAIEmbeddings(max_retries=self.config.options.max_retries, api_key=self.config.credentials.api_key, ) + self.model_func = PatchedOpenAIEmbeddings( + max_retries=self.config.options.max_retries, + api_key=self.config.credentials.api_key, + ) else: - self.model_func = self.registry.get(embedding_model_provider).create(name=self.config.embedding_model, max_retries=self.config.options.max_retries, api_key=self.config.credentials.api_key, ) + self.model_func = self.registry.get(embedding_model_provider).create( + name=self.config.embedding_model, + max_retries=self.config.options.max_retries, + api_key=self.config.credentials.api_key, + ) self.vector_field_name = self.config.vector_field_name self.id_field_name = self.config.id_field_name @@ -152,13 +264,20 @@ def sentinel_table(self) -> str: return self.make_qualified_table_name(self.sentinel_table_name) def make_qualified_table_name(self, table_name: str) -> str: - return (f"{self.dataset_name}{self.config.dataset_separator}{table_name}" if self.dataset_name else table_name) + return ( + f"{self.dataset_name}{self.config.dataset_separator}{table_name}" + if self.dataset_name + else table_name + ) def get_table_schema(self, table_name: str) -> TArrowSchema: schema_table: Table = self.db_client.open_table(table_name) schema_table.checkout_latest() schema = schema_table.schema - return cast(TArrowSchema, schema, ) + return cast( + TArrowSchema, + schema, + ) @lancedb_error def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table: @@ -181,7 +300,11 @@ def delete_table(self, table_name: str) -> None: """ self.db_client.drop_table(table_name) - def query_table(self, table_name: str, query: Union[List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None] = None, ) -> LanceQueryBuilder: + def query_table( + self, + table_name: str, + query: Union[List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None] = None, + ) -> LanceQueryBuilder: """Query a LanceDB table. 
Args: @@ -200,11 +323,15 @@ def _get_table_names(self) -> List[str]: """Return all tables in the dataset, excluding the sentinel table.""" if self.dataset_name: prefix = f"{self.dataset_name}{self.config.dataset_separator}" - table_names = [table_name for table_name in self.db_client.table_names() if table_name.startswith(prefix)] + table_names = [ + table_name + for table_name in self.db_client.table_names() + if table_name.startswith(prefix) + ] else: table_names = self.db_client.table_names() - return [table_name for table_name in table_names if table_name!=self.sentinel_table] + return [table_name for table_name in table_names if table_name != self.sentinel_table] @lancedb_error def drop_storage(self) -> None: @@ -230,7 +357,10 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: continue schema = self.get_table_schema(fq_table_name) self.db_client.drop_table(fq_table_name) - self.create_table(table_name=fq_table_name, schema=schema, ) + self.create_table( + table_name=fq_table_name, + schema=schema, + ) @lancedb_error def is_storage_initialized(self) -> bool: @@ -245,7 +375,11 @@ def _delete_sentinel_table(self) -> None: self.db_client.drop_table(self.sentinel_table) @lancedb_error - def update_stored_schema(self, only_tables: Iterable[str] = None, expected_update: TSchemaTables = None, ) -> Optional[TSchemaTables]: + def update_stored_schema( + self, + only_tables: Iterable[str] = None, + expected_update: TSchemaTables = None, + ) -> Optional[TSchemaTables]: super().update_stored_schema(only_tables, expected_update) applied_update: TSchemaTables = {} @@ -255,13 +389,17 @@ def update_stored_schema(self, only_tables: Iterable[str] = None, expected_updat schema_info = None if schema_info is None: - logger.info(f"Schema with hash {self.schema.stored_version_hash} " - "not found in the storage. upgrading") + logger.info( + f"Schema with hash {self.schema.stored_version_hash} " + "not found in the storage. upgrading" + ) self._execute_schema_update(only_tables) else: - logger.info(f"Schema with hash {self.schema.stored_version_hash} " - f"inserted at {schema_info.inserted_at} found " - "in storage, no upgrade required") + logger.info( + f"Schema with hash {self.schema.stored_version_hash} " + f"inserted at {schema_info.inserted_at} found " + "in storage, no upgrade required" + ) return applied_update def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: @@ -279,11 +417,16 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] field: TArrowField for field in arrow_schema: name = self.schema.naming.normalize_identifier(field.name) - table_schema[name] = {"name": name, **self.type_mapper.from_db_type(field.type), } + table_schema[name] = { + "name": name, + **self.type_mapper.from_db_type(field.type), + } return True, table_schema @lancedb_error - def add_table_fields(self, table_name: str, field_schemas: List[TArrowField]) -> Optional[Table]: + def add_table_fields( + self, table_name: str, field_schemas: List[TArrowField] + ) -> Optional[Table]: """Add multiple fields to the LanceDB table at once. 
Args: @@ -316,16 +459,25 @@ def add_table_fields(self, table_name: str, field_schemas: List[TArrowField]) -> def _execute_schema_update(self, only_tables: Iterable[str]) -> None: for table_name in only_tables or self.schema.tables: exists, existing_columns = self.get_storage_table(table_name) - new_columns = self.schema.get_new_table_columns(table_name, existing_columns, self.capabilities.generates_case_sensitive_identifiers(), ) + new_columns = self.schema.get_new_table_columns( + table_name, + existing_columns, + self.capabilities.generates_case_sensitive_identifiers(), + ) logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") if len(new_columns) > 0: if exists: - field_schemas: List[TArrowField] = [make_arrow_field_schema(column["name"], column, self.type_mapper) for column in new_columns] + field_schemas: List[TArrowField] = [ + make_arrow_field_schema(column["name"], column, self.type_mapper) + for column in new_columns + ] fq_table_name = self.make_qualified_table_name(table_name) self.add_table_fields(fq_table_name, field_schemas) else: if table_name not in self.schema.dlt_table_names(): - embedding_fields = get_columns_names_with_prop(self.schema.get_table(table_name=table_name), VECTORIZE_HINT) + embedding_fields = get_columns_names_with_prop( + self.schema.get_table(table_name=table_name), VECTORIZE_HINT + ) vector_field_name = self.vector_field_name id_field_name = self.id_field_name embedding_model_func = self.model_func @@ -337,7 +489,16 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: embedding_model_func = None embedding_model_dimensions = None - table_schema: TArrowSchema = make_arrow_table_schema(table_name, schema=self.schema, type_mapper=self.type_mapper, embedding_fields=embedding_fields, embedding_model_func=embedding_model_func, embedding_model_dimensions=embedding_model_dimensions, vector_field_name=vector_field_name, id_field_name=id_field_name, ) + table_schema: TArrowSchema = make_arrow_table_schema( + table_name, + schema=self.schema, + type_mapper=self.type_mapper, + embedding_fields=embedding_fields, + embedding_model_func=embedding_model_func, + embedding_model_dimensions=embedding_model_dimensions, + vector_field_name=vector_field_name, + id_field_name=id_field_name, + ) fq_table_name = self.make_qualified_table_name(table_name) self.create_table(fq_table_name, table_schema) @@ -345,13 +506,33 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: @lancedb_error def update_schema_in_storage(self) -> None: - records = [{self.schema.naming.normalize_identifier("version"): self.schema.version, self.schema.naming.normalize_identifier("engine_version"): self.schema.ENGINE_VERSION, - self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), self.schema.naming.normalize_identifier("schema_name"): self.schema.name, - self.schema.naming.normalize_identifier("version_hash"): self.schema.stored_version_hash, self.schema.naming.normalize_identifier("schema"): json.dumps(self.schema.to_dict()), }] + records = [ + { + self.schema.naming.normalize_identifier("version"): self.schema.version, + self.schema.naming.normalize_identifier( + "engine_version" + ): self.schema.ENGINE_VERSION, + self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), + self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier( + "version_hash" + ): self.schema.stored_version_hash, + 
self.schema.naming.normalize_identifier("schema"): json.dumps( + self.schema.to_dict() + ), + } + ] fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) - write_disposition = self.schema.get_table(self.schema.version_table_name).get("write_disposition") + write_disposition = self.schema.get_table(self.schema.version_table_name).get( + "write_disposition" + ) - upload_batch(records, db_client=self.db_client, table_name=fq_version_table_name, write_disposition=write_disposition, ) + write_to_db( + records, + db_client=self.db_client, + table_name=fq_version_table_name, + write_disposition=write_disposition, + ) @lancedb_error def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: @@ -378,18 +559,31 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: # Read the tables into memory as Arrow tables, with pushdown predicates, so we pull as less # data into memory as possible. - state_table = (state_table_.search().where(f"`{p_pipeline_name}` = '{pipeline_name}'", prefilter=True).to_arrow()) + state_table = ( + state_table_.search() + .where(f"`{p_pipeline_name}` = '{pipeline_name}'", prefilter=True) + .to_arrow() + ) loads_table = loads_table_.search().where(f"`{p_status}` = 0", prefilter=True).to_arrow() # Join arrow tables in-memory. - joined_table: pa.Table = state_table.join(loads_table, keys=p_dlt_load_id, right_keys=p_load_id, join_type="inner").sort_by([(p_dlt_load_id, "descending")]) + joined_table: pa.Table = state_table.join( + loads_table, keys=p_dlt_load_id, right_keys=p_load_id, join_type="inner" + ).sort_by([(p_dlt_load_id, "descending")]) - if joined_table.num_rows==0: + if joined_table.num_rows == 0: return None state = joined_table.take([0]).to_pylist()[0] - return StateInfo(version=state[p_version], engine_version=state[p_engine_version], pipeline_name=state[p_pipeline_name], state=state[p_state], created_at=pendulum.instance( - state[p_created_at]), version_hash=state[p_version_hash], _dlt_load_id=state[p_dlt_load_id], ) + return StateInfo( + version=state[p_version], + engine_version=state[p_engine_version], + pipeline_name=state[p_pipeline_name], + state=state[p_state], + created_at=pendulum.instance(state[p_created_at]), + version_hash=state[p_version_hash], + _dlt_load_id=state[p_dlt_load_id], + ) @lancedb_error def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: @@ -405,11 +599,21 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI p_schema = self.schema.naming.normalize_identifier("schema") try: - schemas = (version_table.search().where(f'`{p_version_hash}` = "{schema_hash}"', prefilter=True)).to_list() + schemas = ( + version_table.search().where( + f'`{p_version_hash}` = "{schema_hash}"', prefilter=True + ) + ).to_list() most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] - return StorageSchemaInfo(version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], version=most_recent_schema[p_version], engine_version= - most_recent_schema[p_engine_version], inserted_at=most_recent_schema[p_inserted_at], schema=most_recent_schema[p_schema], ) + return StorageSchemaInfo( + version_hash=most_recent_schema[p_version_hash], + schema_name=most_recent_schema[p_schema_name], + version=most_recent_schema[p_version], + engine_version=most_recent_schema[p_engine_version], + inserted_at=most_recent_schema[p_inserted_at], + schema=most_recent_schema[p_schema], + ) except IndexError: 
return None @@ -428,15 +632,30 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: p_schema = self.schema.naming.normalize_identifier("schema") try: - schemas = (version_table.search().where(f'`{p_schema_name}` = "{self.schema.name}"', prefilter=True)).to_list() + schemas = ( + version_table.search().where( + f'`{p_schema_name}` = "{self.schema.name}"', prefilter=True + ) + ).to_list() most_recent_schema = sorted(schemas, key=lambda x: x[p_inserted_at], reverse=True)[0] - return StorageSchemaInfo(version_hash=most_recent_schema[p_version_hash], schema_name=most_recent_schema[p_schema_name], version=most_recent_schema[p_version], engine_version= - most_recent_schema[p_engine_version], inserted_at=most_recent_schema[p_inserted_at], schema=most_recent_schema[p_schema], ) + return StorageSchemaInfo( + version_hash=most_recent_schema[p_version_hash], + schema_name=most_recent_schema[p_schema_name], + version=most_recent_schema[p_version], + engine_version=most_recent_schema[p_engine_version], + inserted_at=most_recent_schema[p_inserted_at], + schema=most_recent_schema[p_schema], + ) except IndexError: return None - def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType, ) -> None: + def __exit__( + self, + exc_type: Type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ) -> None: pass def __enter__(self) -> "LanceDBClient": @@ -444,13 +663,27 @@ def __enter__(self) -> "LanceDBClient": @lancedb_error def complete_load(self, load_id: str) -> None: - records = [{self.schema.naming.normalize_identifier("load_id"): load_id, self.schema.naming.normalize_identifier("schema_name"): self.schema.name, - self.schema.naming.normalize_identifier("status"): 0, self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), - self.schema.naming.normalize_identifier("schema_version_hash"): None, # Payload schema must match the target schema. - }] + records = [ + { + self.schema.naming.normalize_identifier("load_id"): load_id, + self.schema.naming.normalize_identifier("schema_name"): self.schema.name, + self.schema.naming.normalize_identifier("status"): 0, + self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), + self.schema.naming.normalize_identifier( + "schema_version_hash" + ): None, # Payload schema must match the target schema. 
+ } + ] fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) - write_disposition = self.schema.get_table(self.schema.loads_table_name).get("write_disposition") - upload_batch(records, db_client=self.db_client, table_name=fq_loads_table_name, write_disposition=write_disposition, ) + write_disposition = self.schema.get_table(self.schema.loads_table_name).get( + "write_disposition" + ) + write_to_db( + records, + db_client=self.db_client, + table_name=fq_loads_table_name, + write_disposition=write_disposition, + ) def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") @@ -458,22 +691,48 @@ def restore_file_load(self, file_path: str) -> LoadJob: def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: parent_table = table.get("parent") - return LoadLanceDBJob(self.schema, table, file_path, type_mapper=self.type_mapper, db_client=self.db_client, client_config=self.config, model_func=self.model_func, fq_table_name=self.make_qualified_table_name( - table["name"]), fq_parent_table_name=(self.make_qualified_table_name(parent_table) if parent_table else None), ) - - def create_table_chain_completed_followup_jobs(self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[NewLoadJob]: + return LoadLanceDBJob( + self.schema, + table, + file_path, + type_mapper=self.type_mapper, + db_client=self.db_client, + client_config=self.config, + model_func=self.model_func, + fq_table_name=self.make_qualified_table_name(table["name"]), + fq_parent_table_name=( + self.make_qualified_table_name(parent_table) if parent_table else None + ), + ) + + def create_table_chain_completed_followup_jobs( + self, + table_chain: Sequence[TTableSchema], + completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, + ) -> List[NewLoadJob]: assert completed_table_chain_jobs is not None - jobs = super().create_table_chain_completed_followup_jobs(table_chain, completed_table_chain_jobs) + jobs = super().create_table_chain_completed_followup_jobs( + table_chain, completed_table_chain_jobs + ) for table in table_chain: if table in self.schema.dlt_tables(): continue # Only tables with merge disposition are dispatched for orphan removal jobs. 
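For context, a minimal sketch (assumed, not taken from the patch) of the kind of resource declaration that takes this merge path and therefore gets an orphan-removal followup job; the table and field names are illustrative:

    import dlt

    @dlt.resource(table_name="document", write_disposition="merge", primary_key="doc_id")
    def documents():
        # Re-yielding a document under the same primary key replaces its chunks;
        # the followup job then drops chunks left behind by older loads.
        yield {"doc_id": 1, "text": "updated body"}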
- if table.get("write_disposition")=="merge": + if table.get("write_disposition") == "merge": parent_table = table.get("parent") - jobs.append(LanceDBRemoveOrphansJob(db_client=self.db_client, table_schema=self.prepare_load_table(table["name"]), fq_table_name=self.make_qualified_table_name( - table["name"]), fq_parent_table_name=(self.make_qualified_table_name(parent_table) if parent_table else None), client_config=self.config, )) + jobs.append( + LanceDBRemoveOrphansJob( + db_client=self.db_client, + table_schema=self.prepare_load_table(table["name"]), + fq_table_name=self.make_qualified_table_name(table["name"]), + fq_parent_table_name=( + self.make_qualified_table_name(parent_table) if parent_table else None + ), + client_config=self.config, + ) + ) return jobs @@ -484,8 +743,18 @@ def table_exists(self, table_name: str) -> bool: class LoadLanceDBJob(LoadJob, FollowupJob): arrow_schema: TArrowSchema - def __init__(self, schema: Schema, table_schema: TTableSchema, local_path: str, type_mapper: LanceDBTypeMapper, db_client: DBConnection, client_config: LanceDBClientConfiguration, model_func: TextEmbeddingFunction, fq_table_name: str, fq_parent_table_name: - Optional[str], ) -> None: + def __init__( + self, + schema: Schema, + table_schema: TTableSchema, + local_path: str, + type_mapper: LanceDBTypeMapper, + db_client: DBConnection, + client_config: LanceDBClientConfiguration, + model_func: TextEmbeddingFunction, + fq_table_name: str, + fq_parent_table_name: Optional[str], + ) -> None: file_name = FileStorage.get_file_name_from_file_path(local_path) super().__init__(file_name) self.schema: Schema = schema @@ -500,24 +769,28 @@ def __init__(self, schema: Schema, table_schema: TTableSchema, local_path: str, self.embedding_model_func: TextEmbeddingFunction = model_func self.embedding_model_dimensions: int = client_config.embedding_model_dimensions self.id_field_name: str = client_config.id_field_name - self.write_disposition: TWriteDisposition = cast(TWriteDisposition, self.table_schema.get("write_disposition", "append")) + self.write_disposition: TWriteDisposition = cast( + TWriteDisposition, self.table_schema.get("write_disposition", "append") + ) with FileStorage.open_zipsafe_ro(local_path, mode="rb") as f: - arrow_table: pa.Table = pq.read_table(f, memory_map=True) + arrow_table: pa.Table = pq.read_table(f) if self.table_schema not in self.schema.dlt_tables(): - for record in records: - # Add reserved ID fields. - uuid_id = (generate_uuid(record, self.unique_identifiers, self.fq_table_name) if self.unique_identifiers else str(uuid.uuid4())) - record.update({self.id_field_name: uuid_id}) - - # LanceDB expects all fields in the target arrow table to be present in the data payload. - # We add and set these missing fields, that are fields not present in the target schema, to NULL. 
- missing_fields = set(self.table_schema["columns"]) - set(record) - for field in missing_fields: - record[field] = None - - upload_batch(records, db_client=db_client, table_name=self.fq_table_name, write_disposition=self.write_disposition, id_field_name=self.id_field_name, ) + arrow_table = generate_arrow_uuid_column( + arrow_table, + unique_identifiers=self.unique_identifiers, + table_name=self.fq_table_name, + id_field_name=self.id_field_name, + ) + + write_to_db( + arrow_table, + db_client=db_client, + table_name=self.fq_table_name, + write_disposition=self.write_disposition, + id_field_name=self.id_field_name, + ) def state(self) -> TLoadJobState: return "completed" @@ -527,17 +800,34 @@ def exception(self) -> str: class LanceDBRemoveOrphansJob(NewLoadJobImpl): - def __init__(self, db_client: DBConnection, table_schema: TTableSchema, client_config: LanceDBClientConfiguration, fq_table_name: str, fq_parent_table_name: Optional[str], ) -> None: + def __init__( + self, + db_client: DBConnection, + table_schema: TTableSchema, + client_config: LanceDBClientConfiguration, + fq_table_name: str, + fq_parent_table_name: Optional[str], + ) -> None: self.db_client = db_client self.table_schema: TTableSchema = table_schema self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name - self.write_disposition: TWriteDisposition = cast(TWriteDisposition, self.table_schema.get("write_disposition")) + self.write_disposition: TWriteDisposition = cast( + TWriteDisposition, self.table_schema.get("write_disposition") + ) self.id_field_name: str = client_config.id_field_name - job_id = ParsedLoadJobFileName(table_schema["name"], ParsedLoadJobFileName.new_file_id(), 0, "parquet", ).file_name() + job_id = ParsedLoadJobFileName( + table_schema["name"], + ParsedLoadJobFileName.new_file_id(), + 0, + "parquet", + ).file_name() - super().__init__(file_name=job_id, status="running", ) + super().__init__( + file_name=job_id, + status="running", + ) self._save_text_file("") @@ -546,9 +836,11 @@ def __init__(self, db_client: DBConnection, table_schema: TTableSchema, client_c def execute(self) -> None: orphaned_ids: Set[str] - if self.write_disposition!="merge": - raise DestinationTerminalException(f"Unsupported write disposition {self.write_disposition} for LanceDB Destination" - " Orphan Removal Job - failed AND WILL **NOT** BE RETRIED.") + if self.write_disposition != "merge": + raise DestinationTerminalException( + f"Unsupported write disposition {self.write_disposition} for LanceDB Destination" + " Orphan Removal Job - failed AND WILL **NOT** BE RETRIED." + ) # Orphans are removed irrespective of which merge strategy is picked. try: @@ -558,7 +850,9 @@ def execute(self) -> None: parent_table = self.db_client.open_table(self.fq_parent_table_name) parent_table.checkout_latest() except FileNotFoundError as e: - raise DestinationTransientException("Couldn't open lancedb database. Orphan removal WILL BE RETRIED") from e + raise DestinationTransientException( + "Couldn't open lancedb database. 
Orphan removal WILL BE RETRIED" + ) from e try: if self.fq_parent_table_name: @@ -569,7 +863,7 @@ def execute(self) -> None: if orphaned_ids := child_ids - parent_ids: if len(orphaned_ids) > 1: child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids)==1: + elif len(orphaned_ids) == 1: child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") else: @@ -577,22 +871,31 @@ def execute(self) -> None: # If document ID is defined, we use this as the sole grouping key to identify stale chunks, # else fallback to the compound `id_field_name`. - grouping_key = (get_columns_names_with_prop(self.table_schema, DOCUMENT_ID_HINT) or self.id_field_name) + grouping_key = ( + get_columns_names_with_prop(self.table_schema, DOCUMENT_ID_HINT) + or self.id_field_name + ) grouping_key = grouping_key if isinstance(grouping_key, list) else [grouping_key] - child_table_arrow: pa.Table = child_table.to_arrow().select([*grouping_key, "_dlt_load_id", "_dlt_id"]) + child_table_arrow: pa.Table = child_table.to_arrow().select( + [*grouping_key, "_dlt_load_id", "_dlt_id"] + ) - grouped = child_table_arrow.group_by(grouping_key).aggregate([("_dlt_load_id", "max")]) + grouped = child_table_arrow.group_by(grouping_key).aggregate( + [("_dlt_load_id", "max")] + ) joined = child_table_arrow.join(grouped, keys=grouping_key) orphaned_mask = pc.not_equal(joined["_dlt_load_id"], joined["_dlt_load_id_max"]) orphaned_ids = joined.filter(orphaned_mask).column("_dlt_id").to_pylist() if len(orphaned_ids) > 1: child_table.delete(f"_dlt_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids)==1: + elif len(orphaned_ids) == 1: child_table.delete(f"_dlt_id = '{orphaned_ids.pop()}'") except ArrowInvalid as e: - raise DestinationTerminalException("Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED.") from e + raise DestinationTerminalException( + "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." + ) from e def state(self) -> TLoadJobState: return "completed" diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index f202903598..0edb8487c5 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -2,9 +2,11 @@ import uuid from typing import Sequence, Union, Dict, List +import pyarrow as pa +import pyarrow.compute as pc + from dlt.common.schema import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop -from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider @@ -16,19 +18,35 @@ } -def generate_uuid(data: DictStrAny, unique_identifiers: Sequence[str], table_name: str) -> str: - """Generates deterministic UUID - used for deduplication. +# TODO: Update `generate_arrow_uuid_column` when pyarrow 17.0.0 becomes available with vectorized operations (batched + memory-mapped) +def generate_arrow_uuid_column( + table: pa.Table, unique_identifiers: List[str], id_field_name: str, table_name: str +) -> pa.Table: + """Generates deterministic UUID - used for deduplication, returning a new arrow + table with added UUID column. Args: - data (Dict[str, Any]): Arbitrary data to generate UUID for. - unique_identifiers (Sequence[str]): A list of unique identifiers. - table_name (str): LanceDB table name. + table (pa.Table): PyArrow table to generate UUIDs for. + unique_identifiers (List[str]): A list of unique identifier column names. + id_field_name (str): Name of the new UUID column. 
+ table_name (str): Name of the table. Returns: - str: A string representation of the generated UUID. + pa.Table: New PyArrow table with the new UUID column. """ - data_id = "_".join(str(data[key]) for key in unique_identifiers) - return str(uuid.uuid5(uuid.NAMESPACE_DNS, table_name + data_id)) + + string_columns = [] + for col in unique_identifiers: + column = table[col] + column = pc.cast(column, pa.string()) + column = pc.fill_null(column, "") + string_columns.append(column.to_pylist()) + + concat_values = ["".join(x) for x in zip(*string_columns)] + uuids = [str(uuid.uuid5(uuid.NAMESPACE_OID, x + table_name)) for x in concat_values] + uuid_column = pa.array(uuids) + + return table.append_column(id_field_name, uuid_column) def list_merge_identifiers(table_schema: TTableSchema) -> List[str]: From 0a9682f972f66bb7e45515d89e6765c001acf181 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 5 Aug 2024 17:41:58 +0200 Subject: [PATCH 33/68] Refactor unique identifier handling for LanceDB tables Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 6 +++--- dlt/destinations/impl/lancedb/utils.py | 21 ++++++++++--------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 1294155c00..e337936541 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -71,7 +71,7 @@ TArrowField, ) from dlt.destinations.impl.lancedb.utils import ( - list_merge_identifiers, + get_unique_identifiers_from_table_schema, set_non_standard_providers_environment_variables, generate_arrow_uuid_column, ) @@ -764,7 +764,7 @@ def __init__( self.table_name: str = table_schema["name"] self.fq_table_name: str = fq_table_name self.fq_parent_table_name: Optional[str] = fq_parent_table_name - self.unique_identifiers: List[str] = list_merge_identifiers(table_schema) + self.unique_identifiers: List[str] = get_unique_identifiers_from_table_schema(table_schema) self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) self.embedding_model_func: TextEmbeddingFunction = model_func self.embedding_model_dimensions: int = client_config.embedding_model_dimensions @@ -776,7 +776,7 @@ def __init__( with FileStorage.open_zipsafe_ro(local_path, mode="rb") as f: arrow_table: pa.Table = pq.read_table(f) - if self.table_schema not in self.schema.dlt_tables(): + if self.table_schema['name'] not in self.schema.dlt_table_names(): arrow_table = generate_arrow_uuid_column( arrow_table, unique_identifiers=self.unique_identifiers, diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 0edb8487c5..9e9d05c936 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -35,21 +35,20 @@ def generate_arrow_uuid_column( pa.Table: New PyArrow table with the new UUID column. 
""" - string_columns = [] + unique_identifiers_columns = [] for col in unique_identifiers: column = table[col] column = pc.cast(column, pa.string()) column = pc.fill_null(column, "") - string_columns.append(column.to_pylist()) + unique_identifiers_columns.append(column.to_pylist()) - concat_values = ["".join(x) for x in zip(*string_columns)] - uuids = [str(uuid.uuid5(uuid.NAMESPACE_OID, x + table_name)) for x in concat_values] + concatenated_ids = ["".join(x) for x in zip(*unique_identifiers_columns)] + uuids = [str(uuid.uuid5(uuid.NAMESPACE_OID, x + table_name)) for x in concatenated_ids] uuid_column = pa.array(uuids) - return table.append_column(id_field_name, uuid_column) -def list_merge_identifiers(table_schema: TTableSchema) -> List[str]: +def get_unique_identifiers_from_table_schema(table_schema: TTableSchema) -> List[str]: """Returns a list of merge keys for a table used for either merging or deduplication. Args: @@ -58,12 +57,14 @@ def list_merge_identifiers(table_schema: TTableSchema) -> List[str]: Returns: Sequence[str]: A list of unique column identifiers. """ + primary_keys = get_columns_names_with_prop(table_schema, "primary_key") + merge_keys = [] if table_schema.get("write_disposition") == "merge": - primary_keys = get_columns_names_with_prop(table_schema, "primary_key") merge_keys = get_columns_names_with_prop(table_schema, "merge_key") - if join_keys := list(set(primary_keys + merge_keys)): - return join_keys - return get_columns_names_with_prop(table_schema, "unique") + if join_keys := list(set(primary_keys + merge_keys)): + return join_keys + else: + return get_columns_names_with_prop(table_schema, "unique") def set_non_standard_providers_environment_variables( From a99224a6a3b6315ae143aa68cccea3d47a97c9be Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 5 Aug 2024 18:10:03 +0200 Subject: [PATCH 34/68] Optimize UUID column generation for LanceDB tables Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/utils.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 9e9d05c936..9091d2db1e 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -37,15 +37,17 @@ def generate_arrow_uuid_column( unique_identifiers_columns = [] for col in unique_identifiers: - column = table[col] - column = pc.cast(column, pa.string()) - column = pc.fill_null(column, "") + column = pc.fill_null(pc.cast(table[col], pa.string()), "") unique_identifiers_columns.append(column.to_pylist()) - concatenated_ids = ["".join(x) for x in zip(*unique_identifiers_columns)] - uuids = [str(uuid.uuid5(uuid.NAMESPACE_OID, x + table_name)) for x in concatenated_ids] - uuid_column = pa.array(uuids) - return table.append_column(id_field_name, uuid_column) + uuids = pa.array( + [ + str(uuid.uuid5(uuid.NAMESPACE_OID, x + table_name)) + for x in ["".join(x) for x in zip(*unique_identifiers_columns)] + ] + ) + + return table.append_column(id_field_name, uuids) def get_unique_identifiers_from_table_schema(table_schema: TTableSchema) -> List[str]: From 895331b81f73f4dc7c5112390a97af866734de3c Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 5 Aug 2024 20:04:26 +0200 Subject: [PATCH 35/68] Refactor LanceDBClient to use string type hints for Table Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git 
a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index e337936541..0ad8c9cdfb 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -21,9 +21,9 @@ from lancedb import DBConnection from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore -from lancedb.table import Table # type: ignore +import lancedb.table # type: ignore from numpy import ndarray -from pyarrow import Array, ChunkedArray, ArrowInvalid, Table +from pyarrow import Array, ChunkedArray, ArrowInvalid from dlt.common import json, pendulum, logger from dlt.common.destination import DestinationCapabilitiesContext @@ -271,7 +271,7 @@ def make_qualified_table_name(self, table_name: str) -> str: ) def get_table_schema(self, table_name: str) -> TArrowSchema: - schema_table: Table = self.db_client.open_table(table_name) + schema_table: "lancedb.table.Table" = self.db_client.open_table(table_name) schema_table.checkout_latest() schema = schema_table.schema return cast( @@ -280,7 +280,9 @@ def get_table_schema(self, table_name: str) -> TArrowSchema: ) @lancedb_error - def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table: + def create_table( + self, table_name: str, schema: TArrowSchema, mode: str = "create" + ) -> "lancedb.table.Table": """Create a LanceDB Table from the provided LanceModel or PyArrow schema. Args: @@ -314,7 +316,7 @@ def query_table( Returns: A LanceDB query builder. """ - query_table: Table = self.db_client.open_table(table_name) + query_table: "lancedb.table.Table" = self.db_client.open_table(table_name) query_table.checkout_latest() return query_table.search(query=query) @@ -366,7 +368,7 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: def is_storage_initialized(self) -> bool: return self.table_exists(self.sentinel_table) - def _create_sentinel_table(self) -> Table: + def _create_sentinel_table(self) -> "lancedb.table.Table": """Create an empty table to indicate that the storage is initialized.""" return self.create_table(schema=NULL_SCHEMA, table_name=self.sentinel_table) @@ -408,7 +410,7 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] try: fq_table_name = self.make_qualified_table_name(table_name) - table: Table = self.db_client.open_table(fq_table_name) + table: "lancedb.table.Table" = self.db_client.open_table(fq_table_name) table.checkout_latest() arrow_schema: TArrowSchema = table.schema except FileNotFoundError: @@ -425,15 +427,15 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] @lancedb_error def add_table_fields( - self, table_name: str, field_schemas: List[TArrowField] - ) -> Optional[Table]: + self, table_name: str, field_schemas: List[pa.Field] + ) -> Optional["lancedb.table.Table"]: """Add multiple fields to the LanceDB table at once. Args: - table_name: The name of the table to create the fields on. - field_schemas: The list of fields to create. + table_name: The name of the table to create the fields on. + field_schemas: The list of PyArrow Fields to create. 
""" - table: Table = self.db_client.open_table(table_name) + table: "lancedb.table.Table" = self.db_client.open_table(table_name) table.checkout_latest() arrow_table = table.to_arrow() @@ -540,10 +542,10 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name) fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) - state_table_: Table = self.db_client.open_table(fq_state_table_name) + state_table_: "lancedb.table.Table" = self.db_client.open_table(fq_state_table_name) state_table_.checkout_latest() - loads_table_: Table = self.db_client.open_table(fq_loads_table_name) + loads_table_: "lancedb.table.Table" = self.db_client.open_table(fq_loads_table_name) loads_table_.checkout_latest() # normalize property names @@ -589,7 +591,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]: fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) - version_table: Table = self.db_client.open_table(fq_version_table_name) + version_table: "lancedb.table.Table" = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() p_version_hash = self.schema.naming.normalize_identifier("version_hash") p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") @@ -622,7 +624,7 @@ def get_stored_schema(self) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage.""" fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name) - version_table: Table = self.db_client.open_table(fq_version_table_name) + version_table: "lancedb.table.Table" = self.db_client.open_table(fq_version_table_name) version_table.checkout_latest() p_version_hash = self.schema.naming.normalize_identifier("version_hash") p_inserted_at = self.schema.naming.normalize_identifier("inserted_at") @@ -776,7 +778,7 @@ def __init__( with FileStorage.open_zipsafe_ro(local_path, mode="rb") as f: arrow_table: pa.Table = pq.read_table(f) - if self.table_schema['name'] not in self.schema.dlt_table_names(): + if self.table_schema["name"] not in self.schema.dlt_table_names(): arrow_table = generate_arrow_uuid_column( arrow_table, unique_identifiers=self.unique_identifiers, From a881e7a0571f4764a2212be251ce752d205dcd5d Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 5 Aug 2024 20:07:59 +0200 Subject: [PATCH 36/68] Minor refactor Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 0ad8c9cdfb..278e01d0d1 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -48,7 +48,7 @@ TColumnType, TTableFormat, TTableSchemaColumns, - TWriteDisposition, + TWriteDisposition, TColumnSchema, ) from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName @@ -461,13 +461,13 @@ def add_table_fields( def _execute_schema_update(self, only_tables: Iterable[str]) -> None: for table_name in only_tables or self.schema.tables: exists, existing_columns = self.get_storage_table(table_name) - new_columns = self.schema.get_new_table_columns( + new_columns: List[TColumnSchema] = 
self.schema.get_new_table_columns( table_name, existing_columns, self.capabilities.generates_case_sensitive_identifiers(), ) logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}") - if len(new_columns) > 0: + if new_columns: if exists: field_schemas: List[TArrowField] = [ make_arrow_field_schema(column["name"], column, self.type_mapper) From 7f245e2534c03e23e08db359be45125880bb881a Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 5 Aug 2024 23:21:21 +0200 Subject: [PATCH 37/68] Implement efficient schema update with Nullability support Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 43 ++++++++++--------- dlt/destinations/impl/lancedb/schema.py | 21 ++++++++- dlt/destinations/impl/lancedb/utils.py | 2 + 3 files changed, 44 insertions(+), 22 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 278e01d0d1..9554d12a94 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -15,13 +15,13 @@ ) import lancedb # type: ignore +import lancedb.table # type: ignore import pyarrow as pa import pyarrow.compute as pc import pyarrow.parquet as pq from lancedb import DBConnection from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore -import lancedb.table # type: ignore from numpy import ndarray from pyarrow import Array, ChunkedArray, ArrowInvalid @@ -48,7 +48,8 @@ TColumnType, TTableFormat, TTableSchemaColumns, - TWriteDisposition, TColumnSchema, + TWriteDisposition, + TColumnSchema, ) from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName @@ -69,6 +70,7 @@ TArrowSchema, NULL_SCHEMA, TArrowField, + arrow_datatype_to_fusion_datatype, ) from dlt.destinations.impl.lancedb.utils import ( get_unique_identifiers_from_table_schema, @@ -426,34 +428,33 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] return True, table_schema @lancedb_error - def add_table_fields( - self, table_name: str, field_schemas: List[pa.Field] - ) -> Optional["lancedb.table.Table"]: - """Add multiple fields to the LanceDB table at once. + def extend_lancedb_table_schema(self, table_name: str, field_schemas: List[pa.Field]) -> None: + """Extend LanceDB table schema with empty columns. Args: table_name: The name of the table to create the fields on. - field_schemas: The list of PyArrow Fields to create. + field_schemas: The list of PyArrow Fields to create in the target LanceDB table. """ table: "lancedb.table.Table" = self.db_client.open_table(table_name) table.checkout_latest() - arrow_table = table.to_arrow() - # Check if any of the new fields already exist in the table. - existing_fields = set(arrow_table.schema.names) - new_fields = [field for field in field_schemas if field.name not in existing_fields] - - if not new_fields: - # All fields already present, skip. - return None + try: + # Use DataFusion SQL syntax to alter fields without loading data into client memory. + # Currently, the most efficient way to modify column values is in LanceDB. 
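A standalone sketch of the two LanceDB calls this comment refers to (the hunk below applies them inside extend_lancedb_table_schema); it assumes a throwaway local database, and the table and column names are invented:

    import lancedb

    db = lancedb.connect("/tmp/lancedb_schema_demo")  # throwaway local database
    tbl = db.create_table("evolve_demo", data=[{"id": 1}], mode="overwrite")
    # The new column is defined by a DataFusion SQL expression, so existing rows
    # are not pulled into client memory to be rewritten.
    tbl.add_columns({"new_col": "CAST(NULL AS BIGINT)"})
    # add_columns leaves the field non-nullable by default; flip it explicitly.
    tbl.alter_columns({"path": "new_col", "nullable": True})
    print(tbl.schema)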
+ new_fields = { + field.name: f"CAST(NULL AS {arrow_datatype_to_fusion_datatype(field.type)})" + for field in field_schemas + } + table.add_columns(new_fields) - null_arrays = [pa.nulls(len(arrow_table), type=field.type) for field in new_fields] + # Make new columns nullable in the Arrow schema. + # Necessary because the Datafusion SQL API doesn't set new columns as nullable by default. + for field in field_schemas: + table.alter_columns({"path": field.name, "nullable": field.nullable}) - for field, null_array in zip(new_fields, null_arrays): - arrow_table = arrow_table.append_column(field, null_array) + # TODO: Update method below doesn't work for bulk NULL assignments, raise with LanceDB developers. + # table.update(values={field.name: None}) - try: - return self.db_client.create_table(table_name, arrow_table, mode="overwrite") except OSError: # Error occurred while creating the table, skip. return None @@ -474,7 +475,7 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: for column in new_columns ] fq_table_name = self.make_qualified_table_name(table_name) - self.add_table_fields(fq_table_name, field_schemas) + self.extend_lancedb_table_schema(fq_table_name, field_schemas) else: if table_name not in self.schema.dlt_table_names(): embedding_fields = get_columns_names_with_prop( diff --git a/dlt/destinations/impl/lancedb/schema.py b/dlt/destinations/impl/lancedb/schema.py index c7cceec274..db624aeb12 100644 --- a/dlt/destinations/impl/lancedb/schema.py +++ b/dlt/destinations/impl/lancedb/schema.py @@ -1,6 +1,5 @@ """Utilities for creating arrow schemas from table schemas.""" -from dlt.common.json import json from typing import ( List, cast, @@ -11,6 +10,7 @@ from lancedb.embeddings import TextEmbeddingFunction # type: ignore from typing_extensions import TypeAlias +from dlt.common.json import json from dlt.common.schema import Schema, TColumnSchema from dlt.common.typing import DictStrAny from dlt.destinations.type_mapping import TypeMapper @@ -82,3 +82,22 @@ def make_arrow_table_schema( metadata["embedding_functions"] = json.dumps(embedding_functions).encode("utf-8") return pa.schema(arrow_schema, metadata=metadata) + + +def arrow_datatype_to_fusion_datatype(arrow_type: TArrowSchema) -> str: + type_map = { + pa.bool_(): "BOOLEAN", + pa.int64(): "BIGINT", + pa.float64(): "DOUBLE", + pa.utf8(): "STRING", + pa.binary(): "BYTEA", + pa.date32(): "DATE", + } + + if isinstance(arrow_type, pa.Decimal128Type): + return f"DECIMAL({arrow_type.precision}, {arrow_type.scale})" + + if isinstance(arrow_type, pa.TimestampType): + return "TIMESTAMP" + + return type_map.get(arrow_type, "UNKNOWN") diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 9091d2db1e..0e6d4744a7 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -74,3 +74,5 @@ def set_non_standard_providers_environment_variables( ) -> None: if embedding_model_provider in PROVIDER_ENVIRONMENT_VARIABLES_MAP: os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or "" + + From 4fc73dd04d520592b08e24417a3710305b1fc2bf Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 6 Aug 2024 00:31:14 +0200 Subject: [PATCH 38/68] Optimize orphaned chunks removal for large datasets Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py 
b/dlt/destinations/impl/lancedb/lancedb_client.py index 9554d12a94..a068842814 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -87,6 +87,7 @@ TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} +BATCH_PROCESS_CHUNK_SIZE = 10_000 class LanceDBTypeMapper(TypeMapper): @@ -859,15 +860,25 @@ def execute(self) -> None: try: if self.fq_parent_table_name: - # Chunks and embeddings in child table. - parent_ids = set(pc.unique(parent_table.to_arrow()["_dlt_id"]).to_pylist()) - child_ids = set(pc.unique(child_table.to_arrow()["_dlt_parent_id"]).to_pylist()) - - if orphaned_ids := child_ids - parent_ids: - if len(orphaned_ids) > 1: - child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids) == 1: - child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") + parent_id_iter: "pa.RecordBatchReader" = ( + parent_table.to_lance().scanner(columns=["_dlt_id"]).to_reader() + ) + all_parent_ids = set() + + for batch in parent_id_iter: + chunk_ids = set(pc.unique(batch["_dlt_id"]).to_pylist()) + all_parent_ids.update(chunk_ids) + + # Delete it from db and clear memory. + if len(all_parent_ids) >= BATCH_PROCESS_CHUNK_SIZE: + delete_condition = f"_dlt_parent_id NOT IN {tuple(all_parent_ids)}" + child_table.delete(delete_condition) + all_parent_ids.clear() + + # Process any remaining IDs. + if all_parent_ids: + delete_condition = f"_dlt_parent_id NOT IN {tuple(all_parent_ids)}" + child_table.delete(delete_condition) else: # Chunks and embeddings in the root table. From 9378f505946c064c200719ad36dda4019a8e156b Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 6 Aug 2024 20:05:50 +0200 Subject: [PATCH 39/68] Projection pushdown Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 39 ++++++++++--------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index a068842814..bc57842989 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -860,25 +860,26 @@ def execute(self) -> None: try: if self.fq_parent_table_name: - parent_id_iter: "pa.RecordBatchReader" = ( - parent_table.to_lance().scanner(columns=["_dlt_id"]).to_reader() + # Chunks and embeddings in child table. + # By referencing underlying lance dataset we benefit from projection push-down to storage layer (LanceDB). + parent_ids = set( + pc.unique( + parent_table.to_lance().to_table(columns=["_dlt_id"])["_dlt_id"] + ).to_pylist() + ) + child_ids = set( + pc.unique( + child_table.to_lance().to_table(columns=["_dlt_parent_id"])[ + "_dlt_parent_id" + ] + ).to_pylist() ) - all_parent_ids = set() - - for batch in parent_id_iter: - chunk_ids = set(pc.unique(batch["_dlt_id"]).to_pylist()) - all_parent_ids.update(chunk_ids) - - # Delete it from db and clear memory. - if len(all_parent_ids) >= BATCH_PROCESS_CHUNK_SIZE: - delete_condition = f"_dlt_parent_id NOT IN {tuple(all_parent_ids)}" - child_table.delete(delete_condition) - all_parent_ids.clear() - # Process any remaining IDs. 
- if all_parent_ids: - delete_condition = f"_dlt_parent_id NOT IN {tuple(all_parent_ids)}" - child_table.delete(delete_condition) + if orphaned_ids := child_ids - parent_ids: + if len(orphaned_ids) > 1: + child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") + elif len(orphaned_ids) == 1: + child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") else: # Chunks and embeddings in the root table. @@ -890,8 +891,8 @@ def execute(self) -> None: or self.id_field_name ) grouping_key = grouping_key if isinstance(grouping_key, list) else [grouping_key] - child_table_arrow: pa.Table = child_table.to_arrow().select( - [*grouping_key, "_dlt_load_id", "_dlt_id"] + child_table_arrow: pa.Table = child_table.to_lance().to_table( + columns=[*grouping_key, "_dlt_load_id", "_dlt_id"] ) grouped = child_table_arrow.group_by(grouping_key).aggregate( From 9b14583ad242a39d744ebdb26cbffe5c691d8ff1 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 6 Aug 2024 21:20:22 +0200 Subject: [PATCH 40/68] Format Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 0e6d4744a7..9091d2db1e 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -74,5 +74,3 @@ def set_non_standard_providers_environment_variables( ) -> None: if embedding_model_provider in PROVIDER_ENVIRONMENT_VARIABLES_MAP: os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or "" - - From e21f61bbb7ef804b36e12461c23df6a986b74219 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 6 Aug 2024 22:36:28 +0200 Subject: [PATCH 41/68] Prevent primary key and document ID hint conflict in merge disposition Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 15 ++++-- tests/load/lancedb/test_pipeline.py | 53 +++++++++++++++++++ 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index bc57842989..8d4c031947 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -42,6 +42,7 @@ NewLoadJob, FollowupJob, ) +from dlt.common.exceptions import SystemConfigurationException from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.schema.typing import ( @@ -884,12 +885,18 @@ def execute(self) -> None: else: # Chunks and embeddings in the root table. + document_id_field = get_columns_names_with_prop(self.table_schema, DOCUMENT_ID_HINT) + if document_id_field and get_columns_names_with_prop( + self.table_schema, "primary_key" + ): + raise SystemConfigurationException( + "You CANNOT specify a primary key AND a document ID hint for the same" + " resource when using merge disposition." + ) + # If document ID is defined, we use this as the sole grouping key to identify stale chunks, # else fallback to the compound `id_field_name`. 
- grouping_key = ( - get_columns_names_with_prop(self.table_schema, DOCUMENT_ID_HINT) - or self.id_field_name - ) + grouping_key = document_id_field or self.id_field_name grouping_key = grouping_key if isinstance(grouping_key, list) else [grouping_key] child_table_arrow: pa.Table = child_table.to_lance().to_table( columns=[*grouping_key, "_dlt_load_id", "_dlt_id"] diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index ea52e2f472..4b964604e6 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -16,6 +16,7 @@ ) from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient from dlt.extract import DltResource +from dlt.pipeline.exceptions import PipelineStepFailed from tests.load.lancedb.utils import assert_table, chunk_document, mock_embed from tests.load.utils import sequence_generator, drop_active_pipeline_data from tests.pipeline.utils import assert_load_info @@ -646,3 +647,55 @@ def documents_source( tbl: Table = client.db_client.open_table(embeddings_table_name) # type: ignore[attr-defined] df = tbl.to_pandas() assert set(df["chunk_text"]) == expected_text + + +def test_primary_key_not_compatible_with_doc_id_hint_on_merge_disposition() -> None: + @dlt.resource( # type: ignore + write_disposition="merge", + table_name="document", + primary_key="doc_id", + columns={"doc_id": {DOCUMENT_ID_HINT: True}}, + ) + def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: + for doc in docs: + doc_id = doc["doc_id"] + chunks = chunk_document(doc["text"]) + embeddings = [ + { + "chunk_hash": digest128(chunk), + "chunk_text": chunk, + "embedding": mock_embed(), + } + for chunk in chunks + ] + yield {"doc_id": doc_id, "doc_text": doc["text"], "embeddings": embeddings} + + @dlt.source(max_table_nesting=1) + def documents_source( + docs: List[DictStrAny], + ) -> Union[Generator[Dict[str, Any], None, None], DltResource]: + return documents(docs) + + pipeline = dlt.pipeline( + pipeline_name="test_mandatory_doc_id_hint_on_merge_disposition", + destination="lancedb", + dataset_name="test_mandatory_doc_id_hint_on_merge_disposition", + dev_mode=True, + ) + + initial_docs = [ + { + "text": ( + "This is the first document. It contains some text that will be chunked and" + " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)" + ), + "doc_id": 1, + }, + { + "text": "Here's another document. 
It's a bit different from the first one.", + "doc_id": 2, + }, + ] + + with pytest.raises(PipelineStepFailed): + pipeline.run(documents(initial_docs)) From 9725d0e9fe0c47ca1120c364d8c02f70a1042223 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Wed, 7 Aug 2024 15:57:38 +0200 Subject: [PATCH 42/68] Add recommended file size for LanceDB destination Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/factory.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dlt/destinations/impl/lancedb/factory.py b/dlt/destinations/impl/lancedb/factory.py index beab6a4031..9acb82344b 100644 --- a/dlt/destinations/impl/lancedb/factory.py +++ b/dlt/destinations/impl/lancedb/factory.py @@ -30,6 +30,8 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.decimal_precision = (38, 18) caps.timestamp_precision = 6 + caps.recommended_file_size = 128_000_000 + return caps @property From 5238c1125afc234cd08a734b2454c8d0de4f795c Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Wed, 7 Aug 2024 16:04:07 +0200 Subject: [PATCH 43/68] Improve comment clarity for projection push-down in LanceDB Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 8d4c031947..951c8c725d 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -862,7 +862,8 @@ def execute(self) -> None: try: if self.fq_parent_table_name: # Chunks and embeddings in child table. - # By referencing underlying lance dataset we benefit from projection push-down to storage layer (LanceDB). + + # By referencing the underlying lance dataset, we benefit from projection push-down to the storage layer (LanceDB). 
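For reference, the push-down the reworded comment describes is easy to see in isolation. A minimal sketch, assuming a local LanceDB database and a table called parent_docs that has a `_dlt_id` column next to wide payload columns (both names are illustrative, not part of this patch):

import lancedb
import pyarrow.compute as pc

db = lancedb.connect(".lancedb")    # illustrative local database path
tbl = db.open_table("parent_docs")  # assumed table with `_dlt_id` plus wide payload columns

# Full materialisation: every column is read before most of them are discarded.
ids_full_scan = set(pc.unique(tbl.to_arrow()["_dlt_id"]).to_pylist())

# Projection push-down: the underlying lance dataset reads only `_dlt_id` from storage.
ids_pushed_down = set(
    pc.unique(tbl.to_lance().to_table(columns=["_dlt_id"])["_dlt_id"]).to_pylist()
)

assert ids_full_scan == ids_pushed_down  # same result, far less I/O on wide tables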
parent_ids = set( pc.unique( parent_table.to_lance().to_table(columns=["_dlt_id"])["_dlt_id"] From c8f74680db23deec7219c88223414c7d4af15a48 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 8 Aug 2024 01:09:41 +0200 Subject: [PATCH 44/68] Update to new load interface Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 45 +++++++++---------- dlt/destinations/impl/lancedb/utils.py | 2 +- 2 files changed, 21 insertions(+), 26 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 86f4ce5854..7c69f5db8a 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -38,8 +38,6 @@ RunnableLoadJob, StorageSchemaInfo, StateInfo, - TLoadJobState, - NewLoadJob, FollowupJob, LoadJob, ) @@ -79,8 +77,7 @@ set_non_standard_providers_environment_variables, generate_arrow_uuid_column, ) -from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs -from dlt.destinations.job_impl import EmptyLoadJob, NewLoadJobImpl +from dlt.destinations.job_impl import FollowupJobImpl from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -153,7 +150,7 @@ def from_db_type( ) if isinstance(db_type, pa.Decimal128Type): precision, scale = db_type.precision, db_type.scale - if (precision, scale)==self.capabilities.wei_precision: + if (precision, scale) == self.capabilities.wei_precision: return cast(TColumnType, dict(data_type="wei")) return dict(data_type="decimal", precision=precision, scale=scale) return super().from_db_type(cast(str, db_type), precision, scale) @@ -193,9 +190,9 @@ def write_to_db( try: if write_disposition in ("append", "skip"): tbl.add(records) - elif write_disposition=="replace": + elif write_disposition == "replace": tbl.add(records, mode="overwrite") - elif write_disposition=="merge": + elif write_disposition == "merge": if not id_field_name: raise ValueError("To perform a merge update, 'id_field_name' must be specified.") tbl.merge_insert( @@ -244,7 +241,7 @@ def __init__( self.config.credentials.embedding_model_provider_api_key, ) # Use the monkey-patched implementation if openai was chosen. 
- if embedding_model_provider=="openai": + if embedding_model_provider == "openai": from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings self.model_func = PatchedOpenAIEmbeddings( @@ -339,7 +336,7 @@ def _get_table_names(self) -> List[str]: else: table_names = self.db_client.table_names() - return [table_name for table_name in table_names if table_name!=self.sentinel_table] + return [table_name for table_name in table_names if table_name != self.sentinel_table] @lancedb_error def drop_storage(self) -> None: @@ -578,7 +575,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: loads_table, keys=p_dlt_load_id, right_keys=p_load_id, join_type="inner" ).sort_by([(p_dlt_load_id, "descending")]) - if joined_table.num_rows==0: + if joined_table.num_rows == 0: return None state = joined_table.take([0]).to_pylist()[0] @@ -711,7 +708,7 @@ def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[NewLoadJob]: + ) -> List[FollowupJob]: assert completed_table_chain_jobs is not None jobs = super().create_table_chain_completed_followup_jobs( table_chain, completed_table_chain_jobs @@ -722,7 +719,7 @@ def create_table_chain_completed_followup_jobs( continue # Only tables with merge disposition are dispatched for orphan removal jobs. - if table.get("write_disposition")=="merge": + if table.get("write_disposition") == "merge": parent_table = table.get("parent") jobs.append( LanceDBRemoveOrphansJob( @@ -742,7 +739,7 @@ def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() -class LanceDBLoadJob(RunnableLoadJob, FollowupJob): +class LanceDBLoadJob(RunnableLoadJob): arrow_schema: TArrowSchema def __init__( @@ -765,7 +762,9 @@ def run(self) -> None: self._embedding_model_dimensions: int = self._job_client.config.embedding_model_dimensions self._id_field_name: str = self._job_client.config.id_field_name - unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema(self._load_table) + unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema( + self._load_table + ) write_disposition: TWriteDisposition = cast( TWriteDisposition, self._load_table.get("write_disposition", "append") ) @@ -776,9 +775,9 @@ def run(self) -> None: if self._load_table not in self._schema.dlt_tables(): arrow_table = generate_arrow_uuid_column( arrow_table, - unique_identifiers=self.unique_identifiers, - table_name=self.fq_table_name, - id_field_name=self.id_field_name, + unique_identifiers=unique_identifiers, + table_name=self._fq_table_name, + id_field_name=self._id_field_name, ) write_to_db( @@ -790,7 +789,7 @@ def run(self) -> None: ) -class LanceDBRemoveOrphansJob(NewLoadJobImpl): +class LanceDBRemoveOrphansJob(FollowupJobImpl): def __init__( self, db_client: DBConnection, @@ -817,7 +816,6 @@ def __init__( super().__init__( file_name=job_id, - status="running", ) self._save_text_file("") @@ -827,7 +825,7 @@ def __init__( def execute(self) -> None: orphaned_ids: Set[str] - if self.write_disposition!="merge": + if self.write_disposition != "merge": raise DestinationTerminalException( f"Unsupported write disposition {self.write_disposition} for LanceDB Destination" " Orphan Removal Job - failed AND WILL **NOT** BE RETRIED." 
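The hunks above and below keep reworking the same core step of orphan removal for child chunk tables: collect the parent `_dlt_id`s and the child `_dlt_parent_id`s, take the set difference, and delete the leftovers with a SQL-style predicate. Pulled out of the job class and assuming two already-opened LanceDB tables, the step reads roughly as follows (a sketch, not part of the patch):

from typing import Set

import pyarrow.compute as pc
from lancedb.table import Table


def delete_orphaned_chunks(parent_table: Table, child_table: Table) -> None:
    # Read only the ID columns, relying on projection push-down.
    parent_ids: Set[str] = set(
        pc.unique(
            parent_table.to_lance().to_table(columns=["_dlt_id"])["_dlt_id"]
        ).to_pylist()
    )
    child_ids: Set[str] = set(
        pc.unique(
            child_table.to_lance().to_table(columns=["_dlt_parent_id"])["_dlt_parent_id"]
        ).to_pylist()
    )

    # A chunk whose parent row vanished in the latest load is an orphan.
    if orphaned_ids := child_ids - parent_ids:
        if len(orphaned_ids) > 1:
            child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}")
        else:
            # A one-element tuple would render with a trailing comma, so use equality.
            child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'")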
@@ -866,7 +864,7 @@ def execute(self) -> None: if orphaned_ids := child_ids - parent_ids: if len(orphaned_ids) > 1: child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids)==1: + elif len(orphaned_ids) == 1: child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") else: @@ -898,13 +896,10 @@ def execute(self) -> None: if len(orphaned_ids) > 1: child_table.delete(f"_dlt_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids)==1: + elif len(orphaned_ids) == 1: child_table.delete(f"_dlt_id = '{orphaned_ids.pop()}'") except ArrowInvalid as e: raise DestinationTerminalException( "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." ) from e - - def state(self) -> TLoadJobState: - return "completed" diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 9091d2db1e..466111dc98 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -20,7 +20,7 @@ # TODO: Update `generate_arrow_uuid_column` when pyarrow 17.0.0 becomes available with vectorized operations (batched + memory-mapped) def generate_arrow_uuid_column( - table: pa.Table, unique_identifiers: List[str], id_field_name: str, table_name: str + table: pa.Table, unique_identifiers: Sequence[str], id_field_name: str, table_name: str ) -> pa.Table: """Generates deterministic UUID - used for deduplication, returning a new arrow table with added UUID column. From af561918ccf01ac6e0b08459456cd698e8a92309 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 8 Aug 2024 23:00:03 +0200 Subject: [PATCH 45/68] Remove unnecessary LanceDBLoadJob attributes Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 86 +++++++++++-------- dlt/destinations/sql_jobs.py | 1 - 2 files changed, 52 insertions(+), 35 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 7c69f5db8a..c16f67e702 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -19,6 +19,7 @@ import pyarrow as pa import pyarrow.compute as pc import pyarrow.parquet as pq +import yaml from lancedb import DBConnection from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore @@ -692,16 +693,9 @@ def complete_load(self, load_id: str) -> None: def create_load_job( self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: - parent_table = table.get("parent") - return LanceDBLoadJob( file_path=file_path, - type_mapper=self.type_mapper, - model_func=self.model_func, fq_table_name=self.make_qualified_table_name(table["name"]), - fq_parent_table_name=( - self.make_qualified_table_name(parent_table) if parent_table else None - ), ) def create_table_chain_completed_followup_jobs( @@ -714,25 +708,25 @@ def create_table_chain_completed_followup_jobs( table_chain, completed_table_chain_jobs ) - for table in table_chain: - if table in self.schema.dlt_tables(): - continue - - # Only tables with merge disposition are dispatched for orphan removal jobs. 
- if table.get("write_disposition") == "merge": - parent_table = table.get("parent") - jobs.append( - LanceDBRemoveOrphansJob( - db_client=self.db_client, - table_schema=self.prepare_load_table(table["name"]), - fq_table_name=self.make_qualified_table_name(table["name"]), - fq_parent_table_name=( - self.make_qualified_table_name(parent_table) if parent_table else None - ), - client_config=self.config, - ) - ) - + # for table in table_chain: + # if table in self.schema.dlt_tables(): + # continue + # + # # Only tables with merge disposition are dispatched for orphan removal jobs. + # if table.get("write_disposition") == "merge": + # parent_table = table.get("parent") + # jobs.append( + # LanceDBRemoveOrphansJob( + # db_client=self.db_client, + # table_schema=self.prepare_load_table(table["name"]), + # fq_table_name=self.make_qualified_table_name(table["name"]), + # fq_parent_table_name=( + # self.make_qualified_table_name(parent_table) if parent_table else None + # ), + # client_config=self.config, + # ) + # ) + # return jobs def table_exists(self, table_name: str) -> bool: @@ -745,21 +739,14 @@ class LanceDBLoadJob(RunnableLoadJob): def __init__( self, file_path: str, - type_mapper: LanceDBTypeMapper, - model_func: TextEmbeddingFunction, fq_table_name: str, - fq_parent_table_name: Optional[str], ) -> None: super().__init__(file_path) - self._type_mapper: TypeMapper = type_mapper self._fq_table_name: str = fq_table_name - self._model_func = model_func self._job_client: "LanceDBClient" = None def run(self) -> None: self._db_client: DBConnection = self._job_client.db_client - self._embedding_model_func: TextEmbeddingFunction = self._model_func - self._embedding_model_dimensions: int = self._job_client.config.embedding_model_dimensions self._id_field_name: str = self._job_client.config.id_field_name unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema( @@ -822,7 +809,7 @@ def __init__( self.execute() - def execute(self) -> None: + def run(self) -> None: orphaned_ids: Set[str] if self.write_disposition != "merge": @@ -903,3 +890,34 @@ def execute(self) -> None: raise DestinationTerminalException( "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." ) from e + + @classmethod + def from_table_chain( + cls, + table_chain: Sequence[TTableSchema], + ) -> FollowupJobImpl: + """Generates a list of orphan removal tasks that the client will execute when the job is executed in the loader. + + The `table_chain` contains a listo of table schemas with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). + """ + top_table = table_chain[0] + file_info = ParsedLoadJobFileName( + top_table["name"], ParsedLoadJobFileName.new_file_id(), 0, "parquet" + ) + try: + job = cls(file_info.file_name()) + # Write parquet file contents?? + except Exception as e: + raise LanceDBJobCreationException(e, table_chain) from e + return job + + +class LanceDBJobCreationException(DestinationTransientException): + def __init__(self, original_exception: Exception, table_chain: Sequence[TTableSchema]) -> None: + tables_chain = yaml.dump( + table_chain, allow_unicode=True, default_flow_style=False, sort_keys=False + ) + super().__init__( + f"Could not create SQLFollowupJob with exception {str(original_exception)}. 
Table" + f" chain: {tables_chain}" + ) diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index cddae52bb7..643264968a 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -75,7 +75,6 @@ def from_table_chain( job = cls(file_info.file_name()) job._save_text_file("\n".join(sql)) except Exception as e: - # raise exception with some context raise SqlJobCreationException(e, table_chain) from e return job From 7e33011f52f35759b0a1ec1fe50906daf4762efc Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Fri, 9 Aug 2024 00:31:29 +0200 Subject: [PATCH 46/68] Change instance attributes to `run` method as variables Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 126 ++++++------------ 1 file changed, 40 insertions(+), 86 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index c16f67e702..9ec9288d6b 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -78,7 +78,7 @@ set_non_standard_providers_environment_variables, generate_arrow_uuid_column, ) -from dlt.destinations.job_impl import FollowupJobImpl +from dlt.destinations.job_impl import ReferenceFollowupJob from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -693,41 +693,17 @@ def complete_load(self, load_id: str) -> None: def create_load_job( self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: - return LanceDBLoadJob( - file_path=file_path, - fq_table_name=self.make_qualified_table_name(table["name"]), - ) + if ReferenceFollowupJob.is_reference_job(file_path): + return LanceDBRemoveOrphansJob(file_path, table) + else: + return LanceDBLoadJob(file_path, table) def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJob]: - assert completed_table_chain_jobs is not None - jobs = super().create_table_chain_completed_followup_jobs( - table_chain, completed_table_chain_jobs - ) - - # for table in table_chain: - # if table in self.schema.dlt_tables(): - # continue - # - # # Only tables with merge disposition are dispatched for orphan removal jobs. 
- # if table.get("write_disposition") == "merge": - # parent_table = table.get("parent") - # jobs.append( - # LanceDBRemoveOrphansJob( - # db_client=self.db_client, - # table_schema=self.prepare_load_table(table["name"]), - # fq_table_name=self.make_qualified_table_name(table["name"]), - # fq_parent_table_name=( - # self.make_qualified_table_name(parent_table) if parent_table else None - # ), - # client_config=self.config, - # ) - # ) - # - return jobs + return [LanceDBRemoveOrphansJob.from_table_chain(table_chain)] def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() @@ -739,16 +715,16 @@ class LanceDBLoadJob(RunnableLoadJob): def __init__( self, file_path: str, - fq_table_name: str, + table_schema: TTableSchema, ) -> None: super().__init__(file_path) - self._fq_table_name: str = fq_table_name self._job_client: "LanceDBClient" = None + self._table_schema: TTableSchema = table_schema def run(self) -> None: - self._db_client: DBConnection = self._job_client.db_client - self._id_field_name: str = self._job_client.config.id_field_name - + db_client: DBConnection = self._job_client.db_client + fq_table_name: str = self._job_client.make_qualified_table_name(self._table_schema["name"]) + id_field_name: str = self._job_client.config.id_field_name unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema( self._load_table ) @@ -763,67 +739,44 @@ def run(self) -> None: arrow_table = generate_arrow_uuid_column( arrow_table, unique_identifiers=unique_identifiers, - table_name=self._fq_table_name, - id_field_name=self._id_field_name, + table_name=fq_table_name, + id_field_name=id_field_name, ) write_to_db( arrow_table, - db_client=self._db_client, - table_name=self._fq_table_name, + db_client=db_client, + table_name=fq_table_name, write_disposition=write_disposition, - id_field_name=self._id_field_name, + id_field_name=id_field_name, ) -class LanceDBRemoveOrphansJob(FollowupJobImpl): +class LanceDBRemoveOrphansJob(RunnableLoadJob): + orphaned_ids: Set[str] + def __init__( self, - db_client: DBConnection, + file_path: str, table_schema: TTableSchema, - client_config: LanceDBClientConfiguration, - fq_table_name: str, - fq_parent_table_name: Optional[str], ) -> None: - self.db_client = db_client - self.table_schema: TTableSchema = table_schema - self.fq_table_name: str = fq_table_name - self.fq_parent_table_name: Optional[str] = fq_parent_table_name - self.write_disposition: TWriteDisposition = cast( - TWriteDisposition, self.table_schema.get("write_disposition") - ) - self.id_field_name: str = client_config.id_field_name - - job_id = ParsedLoadJobFileName( - table_schema["name"], - ParsedLoadJobFileName.new_file_id(), - 0, - "parquet", - ).file_name() - - super().__init__( - file_name=job_id, - ) - - self._save_text_file("") - - self.execute() + super().__init__(file_path) + self._job_client: "LanceDBClient" = None + self._table_schema: TTableSchema = table_schema def run(self) -> None: - orphaned_ids: Set[str] - - if self.write_disposition != "merge": - raise DestinationTerminalException( - f"Unsupported write disposition {self.write_disposition} for LanceDB Destination" - " Orphan Removal Job - failed AND WILL **NOT** BE RETRIED." 
- ) + db_client: DBConnection = self._job_client.db_client + fq_table_name: str = self._job_client.make_qualified_table_name(self._table_schema["name"]) + fq_parent_table_name: str = self._job_client.make_qualified_table_name( + self._table_schema["parent"] + ) + id_field_name: str = self._job_client.config.id_field_name - # Orphans are removed irrespective of which merge strategy is picked. try: - child_table = self.db_client.open_table(self.fq_table_name) + child_table = db_client.open_table(fq_table_name) child_table.checkout_latest() - if self.fq_parent_table_name: - parent_table = self.db_client.open_table(self.fq_parent_table_name) + if fq_parent_table_name: + parent_table = db_client.open_table(fq_parent_table_name) parent_table.checkout_latest() except FileNotFoundError as e: raise DestinationTransientException( @@ -831,10 +784,9 @@ def run(self) -> None: ) from e try: - if self.fq_parent_table_name: + if fq_parent_table_name: # Chunks and embeddings in child table. - # By referencing the underlying lance dataset, we benefit from projection push-down to the storage layer (LanceDB). parent_ids = set( pc.unique( parent_table.to_lance().to_table(columns=["_dlt_id"])["_dlt_id"] @@ -857,9 +809,11 @@ def run(self) -> None: else: # Chunks and embeddings in the root table. - document_id_field = get_columns_names_with_prop(self.table_schema, DOCUMENT_ID_HINT) + document_id_field = get_columns_names_with_prop( + self._table_schema, DOCUMENT_ID_HINT + ) if document_id_field and get_columns_names_with_prop( - self.table_schema, "primary_key" + self._table_schema, "primary_key" ): raise SystemConfigurationException( "You CANNOT specify a primary key AND a document ID hint for the same" @@ -868,7 +822,7 @@ def run(self) -> None: # If document ID is defined, we use this as the sole grouping key to identify stale chunks, # else fallback to the compound `id_field_name`. - grouping_key = document_id_field or self.id_field_name + grouping_key = document_id_field or id_field_name grouping_key = grouping_key if isinstance(grouping_key, list) else [grouping_key] child_table_arrow: pa.Table = child_table.to_lance().to_table( columns=[*grouping_key, "_dlt_load_id", "_dlt_id"] @@ -895,10 +849,10 @@ def run(self) -> None: def from_table_chain( cls, table_chain: Sequence[TTableSchema], - ) -> FollowupJobImpl: + ) -> "LanceDBRemoveOrphansJob": """Generates a list of orphan removal tasks that the client will execute when the job is executed in the loader. - The `table_chain` contains a listo of table schemas with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). + The `table_chain` contains a list of table schemas with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). 
""" top_table = table_chain[0] file_info = ParsedLoadJobFileName( From ee7dd0260c65ae44b664744e9174e82c7b53f8e6 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Fri, 9 Aug 2024 23:36:19 +0200 Subject: [PATCH 47/68] Schedule follow up refernce job Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 50 +++++-------------- 1 file changed, 13 insertions(+), 37 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 9ec9288d6b..64a8386eea 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -19,7 +19,6 @@ import pyarrow as pa import pyarrow.compute as pc import pyarrow.parquet as pq -import yaml from lancedb import DBConnection from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore @@ -41,6 +40,8 @@ StateInfo, FollowupJob, LoadJob, + HasFollowupJobs, + TLoadJobState, ) from dlt.common.exceptions import SystemConfigurationException from dlt.common.pendulum import timedelta @@ -53,7 +54,7 @@ TColumnSchema, ) from dlt.common.schema.utils import get_columns_names_with_prop -from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName +from dlt.common.storages import FileStorage, LoadJobInfo from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import ( LanceDBClientConfiguration, @@ -696,14 +697,17 @@ def create_load_job( if ReferenceFollowupJob.is_reference_job(file_path): return LanceDBRemoveOrphansJob(file_path, table) else: - return LanceDBLoadJob(file_path, table) + return LanceDBLoadJobWithFollowup(file_path, table) def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJob]: - return [LanceDBRemoveOrphansJob.from_table_chain(table_chain)] + table_job_paths = [job.file_path for job in completed_table_chain_jobs] + file_name = FileStorage.get_file_name_from_file_path(table_job_paths[0]) + job = ReferenceFollowupJob(file_name, table_job_paths) + return [job] def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() @@ -752,6 +756,11 @@ def run(self) -> None: ) +class LanceDBLoadJobWithFollowup(HasFollowupJobs, LanceDBLoadJob): + def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJob]: + return super().create_followup_jobs(final_state) + + class LanceDBRemoveOrphansJob(RunnableLoadJob): orphaned_ids: Set[str] @@ -786,7 +795,6 @@ def run(self) -> None: try: if fq_parent_table_name: # Chunks and embeddings in child table. - parent_ids = set( pc.unique( parent_table.to_lance().to_table(columns=["_dlt_id"])["_dlt_id"] @@ -808,7 +816,6 @@ def run(self) -> None: else: # Chunks and embeddings in the root table. - document_id_field = get_columns_names_with_prop( self._table_schema, DOCUMENT_ID_HINT ) @@ -844,34 +851,3 @@ def run(self) -> None: raise DestinationTerminalException( "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." ) from e - - @classmethod - def from_table_chain( - cls, - table_chain: Sequence[TTableSchema], - ) -> "LanceDBRemoveOrphansJob": - """Generates a list of orphan removal tasks that the client will execute when the job is executed in the loader. 
- - The `table_chain` contains a list of table schemas with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). - """ - top_table = table_chain[0] - file_info = ParsedLoadJobFileName( - top_table["name"], ParsedLoadJobFileName.new_file_id(), 0, "parquet" - ) - try: - job = cls(file_info.file_name()) - # Write parquet file contents?? - except Exception as e: - raise LanceDBJobCreationException(e, table_chain) from e - return job - - -class LanceDBJobCreationException(DestinationTransientException): - def __init__(self, original_exception: Exception, table_chain: Sequence[TTableSchema]) -> None: - tables_chain = yaml.dump( - table_chain, allow_unicode=True, default_flow_style=False, sort_keys=False - ) - super().__init__( - f"Could not create SQLFollowupJob with exception {str(original_exception)}. Table" - f" chain: {tables_chain}" - ) From df498abf78df8d137c28142d02ee63e912f77525 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sat, 10 Aug 2024 11:08:41 +0200 Subject: [PATCH 48/68] Add follow up lancedb remove orphan job skeleron Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 35 +++++++++++++------ 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 64a8386eea..23cb6cd542 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -41,7 +41,6 @@ FollowupJob, LoadJob, HasFollowupJobs, - TLoadJobState, ) from dlt.common.exceptions import SystemConfigurationException from dlt.common.pendulum import timedelta @@ -54,7 +53,7 @@ TColumnSchema, ) from dlt.common.schema.utils import get_columns_names_with_prop -from dlt.common.storages import FileStorage, LoadJobInfo +from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import ( LanceDBClientConfiguration, @@ -79,7 +78,7 @@ set_non_standard_providers_environment_variables, generate_arrow_uuid_column, ) -from dlt.destinations.job_impl import ReferenceFollowupJob +from dlt.destinations.job_impl import FollowupJobImpl from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -694,7 +693,7 @@ def complete_load(self, load_id: str) -> None: def create_load_job( self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: - if ReferenceFollowupJob.is_reference_job(file_path): + if file_path.endswith(".remove_orphans"): return LanceDBRemoveOrphansJob(file_path, table) else: return LanceDBLoadJobWithFollowup(file_path, table) @@ -704,10 +703,22 @@ def create_table_chain_completed_followup_jobs( table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJob]: - table_job_paths = [job.file_path for job in completed_table_chain_jobs] - file_name = FileStorage.get_file_name_from_file_path(table_job_paths[0]) - job = ReferenceFollowupJob(file_name, table_job_paths) - return [job] + assert completed_table_chain_jobs is not None + jobs = super().create_table_chain_completed_followup_jobs( + table_chain, completed_table_chain_jobs + ) + for table in table_chain: + if ( + table.get("write_disposition") == "merge" + and table["name"] not in self.schema.dlt_table_names() + ): + file_name = repr( + ParsedLoadJobFileName( + table["name"], ParsedLoadJobFileName.new_file_id(), 0, "remove_orphans" + ) + ) + 
jobs.append(FollowupJobImpl(file_name)) + return jobs def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() @@ -757,8 +768,7 @@ def run(self) -> None: class LanceDBLoadJobWithFollowup(HasFollowupJobs, LanceDBLoadJob): - def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJob]: - return super().create_followup_jobs(final_state) + pass class LanceDBRemoveOrphansJob(RunnableLoadJob): @@ -851,3 +861,8 @@ def run(self) -> None: raise DestinationTerminalException( "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." ) from e + + +class LanceDBRemoveOrphansFollowupJob(FollowupJobImpl): + def __init__(self, file_name: str) -> None: + super().__init__(file_name) From c08f1ba948a88eba759162da48e527e1f9c36cb4 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sat, 10 Aug 2024 11:33:07 +0200 Subject: [PATCH 49/68] Write empty follow up file Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 23cb6cd542..180750d8ad 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -1,3 +1,4 @@ +import os from types import TracebackType from typing import ( List, @@ -693,7 +694,7 @@ def complete_load(self, load_id: str) -> None: def create_load_job( self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: - if file_path.endswith(".remove_orphans"): + if LanceDBRemoveOrphansFollowupJob.is_remove_orphan_job(file_path): return LanceDBRemoveOrphansJob(file_path, table) else: return LanceDBLoadJobWithFollowup(file_path, table) @@ -703,22 +704,7 @@ def create_table_chain_completed_followup_jobs( table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJob]: - assert completed_table_chain_jobs is not None - jobs = super().create_table_chain_completed_followup_jobs( - table_chain, completed_table_chain_jobs - ) - for table in table_chain: - if ( - table.get("write_disposition") == "merge" - and table["name"] not in self.schema.dlt_table_names() - ): - file_name = repr( - ParsedLoadJobFileName( - table["name"], ParsedLoadJobFileName.new_file_id(), 0, "remove_orphans" - ) - ) - jobs.append(FollowupJobImpl(file_name)) - return jobs + return LanceDBRemoveOrphansFollowupJob.from_table_chain(table_chain) def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() @@ -866,3 +852,22 @@ def run(self) -> None: class LanceDBRemoveOrphansFollowupJob(FollowupJobImpl): def __init__(self, file_name: str) -> None: super().__init__(file_name) + self._save_text_file("") + + @classmethod + def from_table_chain( + cls, + table_chain: Sequence[TTableSchema], + ) -> List[FollowupJobImpl]: + jobs = [] + for table in table_chain: + if table.get("write_disposition") == "merge": + file_info = ParsedLoadJobFileName( + table["name"], ParsedLoadJobFileName.new_file_id(), 0, "remove_orphans" + ) + jobs.append(cls(file_info.file_name())) + return jobs + + @staticmethod + def is_remove_orphan_job(file_path: str) -> bool: + return os.path.basename(file_path) == ".remove_orphans" From f9f94e3d5024dc7ca194758140ad3b62bd4d9597 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sat, 10 Aug 2024 15:22:41 +0200 Subject: [PATCH 50/68] Write parquet 
Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 180750d8ad..57aa956286 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -852,7 +852,10 @@ def run(self) -> None: class LanceDBRemoveOrphansFollowupJob(FollowupJobImpl): def __init__(self, file_name: str) -> None: super().__init__(file_name) - self._save_text_file("") + self._write_empty_parquet_file() + + def _write_empty_parquet_file(self): + pq.write_table(pa.table({}), self._file_path) @classmethod def from_table_chain( @@ -862,12 +865,14 @@ def from_table_chain( jobs = [] for table in table_chain: if table.get("write_disposition") == "merge": + # TODO: insert Identify orphan IDs into load job file here. Removal should then happen through `write_to_db` in orphan removal job. file_info = ParsedLoadJobFileName( - table["name"], ParsedLoadJobFileName.new_file_id(), 0, "remove_orphans" + table["name"], ParsedLoadJobFileName.new_file_id(), 0, "parquet" ) jobs.append(cls(file_info.file_name())) return jobs @staticmethod def is_remove_orphan_job(file_path: str) -> bool: - return os.path.basename(file_path) == ".remove_orphans" + return os.path.basename(file_path) == ".parquet" + From b374b0b73800682797db48bdb3fccf2ccd45799c Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sat, 10 Aug 2024 16:03:39 +0200 Subject: [PATCH 51/68] Add support for reference file format in LanceDB destination Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/factory.py | 2 +- dlt/destinations/impl/lancedb/lancedb_client.py | 11 +++-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/dlt/destinations/impl/lancedb/factory.py b/dlt/destinations/impl/lancedb/factory.py index 9acb82344b..d9b92e02b9 100644 --- a/dlt/destinations/impl/lancedb/factory.py +++ b/dlt/destinations/impl/lancedb/factory.py @@ -17,7 +17,7 @@ class lancedb(Destination[LanceDBClientConfiguration, "LanceDBClient"]): def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = "parquet" - caps.supported_loader_file_formats = ["parquet"] + caps.supported_loader_file_formats = ["parquet", "reference"] caps.max_identifier_length = 200 caps.max_column_identifier_length = 1024 diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 57aa956286..94d62fc8f1 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -852,10 +852,7 @@ def run(self) -> None: class LanceDBRemoveOrphansFollowupJob(FollowupJobImpl): def __init__(self, file_name: str) -> None: super().__init__(file_name) - self._write_empty_parquet_file() - - def _write_empty_parquet_file(self): - pq.write_table(pa.table({}), self._file_path) + self._save_text_file("") @classmethod def from_table_chain( @@ -865,14 +862,12 @@ def from_table_chain( jobs = [] for table in table_chain: if table.get("write_disposition") == "merge": - # TODO: insert Identify orphan IDs into load job file here. Removal should then happen through `write_to_db` in orphan removal job. 
file_info = ParsedLoadJobFileName( - table["name"], ParsedLoadJobFileName.new_file_id(), 0, "parquet" + table["name"], ParsedLoadJobFileName.new_file_id(), 0, "reference" ) jobs.append(cls(file_info.file_name())) return jobs @staticmethod def is_remove_orphan_job(file_path: str) -> bool: - return os.path.basename(file_path) == ".parquet" - + return os.path.splitext(file_path)[1][1:] == "reference" From 2ed3301d094dc49f3786e95876e135bf7aacaa96 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sat, 10 Aug 2024 19:30:56 +0200 Subject: [PATCH 52/68] Handle parent table name resolution if it doesn't exist in Lance db remove orphan job Signed-off-by: Marcel Coetzee --- dlt/common/schema/exceptions.py | 9 ++++++--- dlt/common/schema/utils.py | 4 +++- dlt/destinations/impl/lancedb/lancedb_client.py | 17 +++++++++-------- dlt/normalize/schema.py | 5 ++++- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/dlt/common/schema/exceptions.py b/dlt/common/schema/exceptions.py index 1055163942..2e75b4b3a1 100644 --- a/dlt/common/schema/exceptions.py +++ b/dlt/common/schema/exceptions.py @@ -246,12 +246,15 @@ def __init__(self, schema_name: str, table_name: str, column: TColumnSchemaBase) elif column.get("primary_key"): key_type = "primary key" - msg = f"The column {column['name']} in table {table_name} did not receive any data during this load. " + msg = ( + f"The column {column['name']} in table {table_name} did not receive any data during" + " this load. " + ) if key_type or not nullable: msg += f"It is marked as non-nullable{' '+key_type} and it must have values. " msg += ( - "This can happen if you specify the column manually, for example using the 'merge_key', 'primary_key' or 'columns' argument " - "but it does not exist in the data." + "This can happen if you specify the column manually, for example using the 'merge_key'," + " 'primary_key' or 'columns' argument but it does not exist in the data." ) super().__init__(schema_name, msg) diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index d879c21b3c..8b87a7e5fe 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -357,7 +357,9 @@ def is_nullable_column(col: TColumnSchemaBase) -> bool: return col.get("nullable", True) -def find_incomplete_columns(tables: List[TTableSchema]) -> Iterable[Tuple[str, TColumnSchemaBase, bool]]: +def find_incomplete_columns( + tables: List[TTableSchema], +) -> Iterable[Tuple[str, TColumnSchemaBase, bool]]: """Yields (table_name, column, nullable) for all incomplete columns in `tables`""" for table in tables: for col in table["columns"].values(): diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 94d62fc8f1..957f1595da 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -675,9 +675,7 @@ def complete_load(self, load_id: str) -> None: self.schema.naming.normalize_identifier("schema_name"): self.schema.name, self.schema.naming.normalize_identifier("status"): 0, self.schema.naming.normalize_identifier("inserted_at"): str(pendulum.now()), - self.schema.naming.normalize_identifier( - "schema_version_hash" - ): None, # Payload schema must match the target schema. 
+ self.schema.naming.normalize_identifier("schema_version_hash"): None, } ] fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name) @@ -704,7 +702,7 @@ def create_table_chain_completed_followup_jobs( table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJob]: - return LanceDBRemoveOrphansFollowupJob.from_table_chain(table_chain) + return LanceDBRemoveOrphansFollowupJob.from_table_chain(table_chain) # type: ignore def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() @@ -772,9 +770,12 @@ def __init__( def run(self) -> None: db_client: DBConnection = self._job_client.db_client fq_table_name: str = self._job_client.make_qualified_table_name(self._table_schema["name"]) - fq_parent_table_name: str = self._job_client.make_qualified_table_name( - self._table_schema["parent"] - ) + try: + fq_parent_table_name: str = self._job_client.make_qualified_table_name( + self._table_schema["parent"] + ) + except KeyError: + fq_parent_table_name = None # The table is a root table. id_field_name: str = self._job_client.config.id_field_name try: @@ -858,7 +859,7 @@ def __init__(self, file_name: str) -> None: def from_table_chain( cls, table_chain: Sequence[TTableSchema], - ) -> List[FollowupJobImpl]: + ) -> List["LanceDBRemoveOrphansFollowupJob"]: jobs = [] for table in table_chain: if table.get("write_disposition") == "merge": diff --git a/dlt/normalize/schema.py b/dlt/normalize/schema.py index 4967fab18f..c01d184c92 100644 --- a/dlt/normalize/schema.py +++ b/dlt/normalize/schema.py @@ -3,13 +3,16 @@ from dlt.common.schema.exceptions import UnboundColumnException from dlt.common import logger + def verify_normalized_schema(schema: Schema) -> None: """Verify the schema is valid for next stage after normalization. 1. Log warning if any incomplete nullable columns are in any data tables 2. Raise `UnboundColumnException` on incomplete non-nullable columns (e.g. 
missing merge/primary key) """ - for table_name, column, nullable in find_incomplete_columns(schema.data_tables(seen_data_only=True)): + for table_name, column, nullable in find_incomplete_columns( + schema.data_tables(seen_data_only=True) + ): exc = UnboundColumnException(schema.name, table_name, column) if nullable: logger.warning(str(exc)) From 0694859567f083a0a90087195481a033dc3e9231 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 15 Aug 2024 22:58:38 +0200 Subject: [PATCH 53/68] Refactor specialised orphan follow up job back to reference job Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 53 +++++++------------ 1 file changed, 18 insertions(+), 35 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 957f1595da..d53dda9f36 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -1,4 +1,3 @@ -import os from types import TracebackType from typing import ( List, @@ -54,7 +53,7 @@ TColumnSchema, ) from dlt.common.schema.utils import get_columns_names_with_prop -from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName +from dlt.common.storages import FileStorage, LoadJobInfo from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import ( LanceDBClientConfiguration, @@ -79,7 +78,7 @@ set_non_standard_providers_environment_variables, generate_arrow_uuid_column, ) -from dlt.destinations.job_impl import FollowupJobImpl +from dlt.destinations.job_impl import ReferenceFollowupJob from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -692,23 +691,35 @@ def complete_load(self, load_id: str) -> None: def create_load_job( self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: - if LanceDBRemoveOrphansFollowupJob.is_remove_orphan_job(file_path): + if ReferenceFollowupJob.is_reference_job(file_path): return LanceDBRemoveOrphansJob(file_path, table) else: - return LanceDBLoadJobWithFollowup(file_path, table) + return LanceDBLoadJob(file_path, table) def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJob]: - return LanceDBRemoveOrphansFollowupJob.from_table_chain(table_chain) # type: ignore + jobs = super().create_table_chain_completed_followup_jobs( + table_chain, completed_table_chain_jobs + ) + if table_chain[0].get("write_disposition") == "merge": + for table in table_chain: + table_job_paths = [ + job.file_path + for job in completed_table_chain_jobs + if job.job_file_info.table_name == table["name"] + ] + file_name = FileStorage.get_file_name_from_file_path(table_job_paths[0]) + jobs.append(ReferenceFollowupJob(file_name, table_job_paths)) + return jobs def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() -class LanceDBLoadJob(RunnableLoadJob): +class LanceDBLoadJob(RunnableLoadJob, HasFollowupJobs): arrow_schema: TArrowSchema def __init__( @@ -751,10 +762,6 @@ def run(self) -> None: ) -class LanceDBLoadJobWithFollowup(HasFollowupJobs, LanceDBLoadJob): - pass - - class LanceDBRemoveOrphansJob(RunnableLoadJob): orphaned_ids: Set[str] @@ -848,27 +855,3 @@ def run(self) -> None: raise DestinationTerminalException( "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." 
) from e - - -class LanceDBRemoveOrphansFollowupJob(FollowupJobImpl): - def __init__(self, file_name: str) -> None: - super().__init__(file_name) - self._save_text_file("") - - @classmethod - def from_table_chain( - cls, - table_chain: Sequence[TTableSchema], - ) -> List["LanceDBRemoveOrphansFollowupJob"]: - jobs = [] - for table in table_chain: - if table.get("write_disposition") == "merge": - file_info = ParsedLoadJobFileName( - table["name"], ParsedLoadJobFileName.new_file_id(), 0, "reference" - ) - jobs.append(cls(file_info.file_name())) - return jobs - - @staticmethod - def is_remove_orphan_job(file_path: str) -> bool: - return os.path.splitext(file_path)[1][1:] == "reference" From 537a2be7f02e528f587ded57f82736578e45ec22 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sat, 17 Aug 2024 19:30:16 +0200 Subject: [PATCH 54/68] Refactor orphan removal for chunked documents Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 174 +++++++++--------- 1 file changed, 92 insertions(+), 82 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index d53dda9f36..dd59fa981a 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -53,7 +53,7 @@ TColumnSchema, ) from dlt.common.schema.utils import get_columns_names_with_prop -from dlt.common.storages import FileStorage, LoadJobInfo +from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import ( LanceDBClientConfiguration, @@ -692,7 +692,7 @@ def create_load_job( self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: if ReferenceFollowupJob.is_reference_job(file_path): - return LanceDBRemoveOrphansJob(file_path, table) + return LanceDBRemoveOrphansJob(file_path) else: return LanceDBLoadJob(file_path, table) @@ -705,14 +705,17 @@ def create_table_chain_completed_followup_jobs( table_chain, completed_table_chain_jobs ) if table_chain[0].get("write_disposition") == "merge": - for table in table_chain: - table_job_paths = [ - job.file_path - for job in completed_table_chain_jobs - if job.job_file_info.table_name == table["name"] - ] - file_name = FileStorage.get_file_name_from_file_path(table_job_paths[0]) - jobs.append(ReferenceFollowupJob(file_name, table_job_paths)) + # TODO: Use staging to write deletion records. For now we use only one job. + all_job_paths_ordered = [ + job.file_path + for table in table_chain + for job in completed_table_chain_jobs + if job.job_file_info.table_name == table.get("name") + ] + root_table_file_name = FileStorage.get_file_name_from_file_path( + all_job_paths_ordered[0] + ) + jobs.append(ReferenceFollowupJob(root_table_file_name, all_job_paths_ordered)) return jobs def table_exists(self, table_name: str) -> bool: @@ -762,96 +765,103 @@ def run(self) -> None: ) +# TODO: Implement staging for this step with insert deletes. 
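For root tables that hold the chunks and embeddings directly (the branch a few hunks further down), stale chunks are found by grouping on a document key and keeping only the rows from the newest `_dlt_load_id` per group. The same arithmetic on a throwaway pyarrow table with invented data, as a self-contained sketch:

import pyarrow as pa
import pyarrow.compute as pc

# Toy stand-in for the root table: document d1 was re-loaded once, d2 was not.
chunks = pa.table(
    {
        "doc_id": ["d1", "d1", "d1", "d2"],
        "_dlt_load_id": ["load_1", "load_2", "load_2", "load_1"],
        "_dlt_id": ["a", "b", "c", "d"],
    }
)
grouping_key = ["doc_id"]

# Newest load id per document.
grouped = chunks.group_by(grouping_key).aggregate([("_dlt_load_id", "max")])

# Every chunk written by an older load than the newest one for its document is stale.
joined = chunks.join(grouped, keys=grouping_key)
orphaned_mask = pc.not_equal(joined["_dlt_load_id"], joined["_dlt_load_id_max"])
orphaned_ids = joined.filter(orphaned_mask).column("_dlt_id").to_pylist()

print(orphaned_ids)  # ['a'], the chunk left over from d1's superseded load_1 run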
class LanceDBRemoveOrphansJob(RunnableLoadJob): orphaned_ids: Set[str] def __init__( self, file_path: str, - table_schema: TTableSchema, ) -> None: super().__init__(file_path) self._job_client: "LanceDBClient" = None - self._table_schema: TTableSchema = table_schema + self.references = ReferenceFollowupJob.resolve_references(file_path) def run(self) -> None: db_client: DBConnection = self._job_client.db_client - fq_table_name: str = self._job_client.make_qualified_table_name(self._table_schema["name"]) - try: - fq_parent_table_name: str = self._job_client.make_qualified_table_name( - self._table_schema["parent"] - ) - except KeyError: - fq_parent_table_name = None # The table is a root table. id_field_name: str = self._job_client.config.id_field_name - try: - child_table = db_client.open_table(fq_table_name) - child_table.checkout_latest() - if fq_parent_table_name: - parent_table = db_client.open_table(fq_parent_table_name) - parent_table.checkout_latest() - except FileNotFoundError as e: - raise DestinationTransientException( - "Couldn't open lancedb database. Orphan removal WILL BE RETRIED" - ) from e - - try: - if fq_parent_table_name: - # Chunks and embeddings in child table. - parent_ids = set( - pc.unique( - parent_table.to_lance().to_table(columns=["_dlt_id"])["_dlt_id"] - ).to_pylist() - ) - child_ids = set( - pc.unique( - child_table.to_lance().to_table(columns=["_dlt_parent_id"])[ - "_dlt_parent_id" - ] - ).to_pylist() + # We don't all insert jobs for each table using this method. + table_lineage: List[TTableSchema] = [] + for file_path_ in self.references: + table = self._schema.get_table(ParsedLoadJobFileName.parse(file_path_).table_name) + if table["name"] not in [table_["name"] for table_ in table_lineage]: + table_lineage.append(table) + + for table in table_lineage: + fq_table_name: str = self._job_client.make_qualified_table_name(table["name"]) + try: + fq_parent_table_name: str = self._job_client.make_qualified_table_name( + table["parent"] ) + except KeyError: + fq_parent_table_name = None # The table is a root table. + + try: + child_table = db_client.open_table(fq_table_name) + child_table.checkout_latest() + if fq_parent_table_name: + parent_table = db_client.open_table(fq_parent_table_name) + parent_table.checkout_latest() + except FileNotFoundError as e: + raise DestinationTransientException( + "Couldn't open lancedb database. Orphan removal WILL BE RETRIED" + ) from e + + try: + if fq_parent_table_name: + # Chunks and embeddings in child table. + parent_ids = set( + pc.unique( + parent_table.to_lance().to_table(columns=["_dlt_id"])["_dlt_id"] + ).to_pylist() + ) + child_ids = set( + pc.unique( + child_table.to_lance().to_table(columns=["_dlt_parent_id"])[ + "_dlt_parent_id" + ] + ).to_pylist() + ) - if orphaned_ids := child_ids - parent_ids: - if len(orphaned_ids) > 1: - child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids) == 1: - child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") + if orphaned_ids := child_ids - parent_ids: + if len(orphaned_ids) > 1: + child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") + elif len(orphaned_ids) == 1: + child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") - else: - # Chunks and embeddings in the root table. 
- document_id_field = get_columns_names_with_prop( - self._table_schema, DOCUMENT_ID_HINT - ) - if document_id_field and get_columns_names_with_prop( - self._table_schema, "primary_key" - ): - raise SystemConfigurationException( - "You CANNOT specify a primary key AND a document ID hint for the same" - " resource when using merge disposition." - ) + else: + # Chunks and embeddings in the root table. + document_id_field = get_columns_names_with_prop(table, DOCUMENT_ID_HINT) + if document_id_field and get_columns_names_with_prop(table, "primary_key"): + raise SystemConfigurationException( + "You CANNOT specify a primary key AND a document ID hint for the same" + " resource when using merge disposition." + ) - # If document ID is defined, we use this as the sole grouping key to identify stale chunks, - # else fallback to the compound `id_field_name`. - grouping_key = document_id_field or id_field_name - grouping_key = grouping_key if isinstance(grouping_key, list) else [grouping_key] - child_table_arrow: pa.Table = child_table.to_lance().to_table( - columns=[*grouping_key, "_dlt_load_id", "_dlt_id"] - ) + # If document ID is defined, we use this as the sole grouping key to identify stale chunks, + # else fallback to the compound `id_field_name`. + grouping_key = document_id_field or id_field_name + grouping_key = ( + grouping_key if isinstance(grouping_key, list) else [grouping_key] + ) + child_table_arrow: pa.Table = child_table.to_lance().to_table( + columns=[*grouping_key, "_dlt_load_id", "_dlt_id"] + ) - grouped = child_table_arrow.group_by(grouping_key).aggregate( - [("_dlt_load_id", "max")] - ) - joined = child_table_arrow.join(grouped, keys=grouping_key) - orphaned_mask = pc.not_equal(joined["_dlt_load_id"], joined["_dlt_load_id_max"]) - orphaned_ids = joined.filter(orphaned_mask).column("_dlt_id").to_pylist() + grouped = child_table_arrow.group_by(grouping_key).aggregate( + [("_dlt_load_id", "max")] + ) + joined = child_table_arrow.join(grouped, keys=grouping_key) + orphaned_mask = pc.not_equal(joined["_dlt_load_id"], joined["_dlt_load_id_max"]) + orphaned_ids = joined.filter(orphaned_mask).column("_dlt_id").to_pylist() - if len(orphaned_ids) > 1: - child_table.delete(f"_dlt_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids) == 1: - child_table.delete(f"_dlt_id = '{orphaned_ids.pop()}'") + if len(orphaned_ids) > 1: + child_table.delete(f"_dlt_id IN {tuple(orphaned_ids)}") + elif len(orphaned_ids) == 1: + child_table.delete(f"_dlt_id = '{orphaned_ids.pop()}'") - except ArrowInvalid as e: - raise DestinationTerminalException( - "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." - ) from e + except ArrowInvalid as e: + raise DestinationTerminalException( + "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." 
+ ) from e From 3d2530656325b9cbfe8fa2c3fc25625352a355ab Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 18 Aug 2024 21:13:44 +0200 Subject: [PATCH 55/68] Fix dlt system table check for name instead of object Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 2 +- dlt/destinations/impl/lancedb/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index dd59fa981a..ecdd22ca56 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -748,7 +748,7 @@ def run(self) -> None: with FileStorage.open_zipsafe_ro(self._file_path, mode="rb") as f: arrow_table: pa.Table = pq.read_table(f) - if self._load_table not in self._schema.dlt_tables(): + if self._load_table["name"] not in self._schema.dlt_table_names(): arrow_table = generate_arrow_uuid_column( arrow_table, unique_identifiers=unique_identifiers, diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 466111dc98..37303686df 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -27,7 +27,7 @@ def generate_arrow_uuid_column( Args: table (pa.Table): PyArrow table to generate UUIDs for. - unique_identifiers (List[str]): A list of unique identifier column names. + unique_identifiers (Sequence[str]): A list of unique identifier column names. id_field_name (str): Name of the new UUID column. table_name (str): Name of the table. From 2ee8da11f3c720fb6190db6af69debf994d91bd3 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Sun, 18 Aug 2024 23:55:38 +0200 Subject: [PATCH 56/68] Implement staging methods Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index ecdd22ca56..a0cb6bd5e4 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -1,3 +1,4 @@ +import contextlib from types import TracebackType from typing import ( List, @@ -12,6 +13,7 @@ Sequence, TYPE_CHECKING, Set, + Iterator, ) import lancedb # type: ignore @@ -41,6 +43,7 @@ FollowupJob, LoadJob, HasFollowupJobs, + WithStagingDataset, ) from dlt.common.exceptions import SystemConfigurationException from dlt.common.pendulum import timedelta @@ -210,10 +213,12 @@ def write_to_db( ) from e -class LanceDBClient(JobClientBase, WithStateSync): +class LanceDBClient(JobClientBase, WithStateSync, WithStagingDataset): """LanceDB destination handler.""" model_func: TextEmbeddingFunction + """The embedder callback used for each chunk.""" + dataset_name: str def __init__( self, @@ -231,6 +236,7 @@ def __init__( self.registry = EmbeddingFunctionRegistry.get_instance() self.type_mapper = LanceDBTypeMapper(self.capabilities) self.sentinel_table_name = config.sentinel_table_name + self.dataset_name = self.config.normalize_dataset_name(self.schema) embedding_model_provider = self.config.embedding_model_provider @@ -259,10 +265,6 @@ def __init__( self.vector_field_name = self.config.vector_field_name self.id_field_name = self.config.id_field_name - @property - def dataset_name(self) -> str: - return self.config.normalize_dataset_name(self.schema) - @property def sentinel_table(self) -> str: return 
self.make_qualified_table_name(self.sentinel_table_name) @@ -721,6 +723,20 @@ def create_table_chain_completed_followup_jobs( def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() + @contextlib.contextmanager + def with_staging_dataset(self) -> Iterator["LanceDBClient"]: + current_dataset_name = self.dataset_name + try: + self.dataset_name = self.schema.naming.normalize_table_identifier( + f"{current_dataset_name}_staging" + ) + yield self + finally: + self.dataset_name = current_dataset_name + + def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: + return table["write_disposition"] == "merge" + class LanceDBLoadJob(RunnableLoadJob, HasFollowupJobs): arrow_schema: TArrowSchema From 2947d55e94acd56769a3d58c4bd6db4e856d6569 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 19 Aug 2024 15:40:43 +0200 Subject: [PATCH 57/68] Override staging client methods Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index a0cb6bd5e4..53da145142 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -707,7 +707,6 @@ def create_table_chain_completed_followup_jobs( table_chain, completed_table_chain_jobs ) if table_chain[0].get("write_disposition") == "merge": - # TODO: Use staging to write deletion records. For now we use only one job. all_job_paths_ordered = [ job.file_path for table in table_chain @@ -735,7 +734,7 @@ def with_staging_dataset(self) -> Iterator["LanceDBClient"]: self.dataset_name = current_dataset_name def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: - return table["write_disposition"] == "merge" + return False class LanceDBLoadJob(RunnableLoadJob, HasFollowupJobs): From 1fcce519649b5b41d7a0f707020d99cdbbf4276a Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 20 Aug 2024 13:46:11 +0200 Subject: [PATCH 58/68] Override staging client methods Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 53da145142..1081aae576 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -780,7 +780,6 @@ def run(self) -> None: ) -# TODO: Implement staging for this step with insert deletes. 
class LanceDBRemoveOrphansJob(RunnableLoadJob): orphaned_ids: Set[str] From 8849f11cb56914ef89e5b0ba4ec59d346d1185a6 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 20 Aug 2024 21:23:22 +0200 Subject: [PATCH 59/68] Delete with inserts Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 155 +++++++----------- dlt/destinations/impl/lancedb/schema.py | 5 +- 2 files changed, 61 insertions(+), 99 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 1081aae576..82cbdcba26 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -19,13 +19,13 @@ import lancedb # type: ignore import lancedb.table # type: ignore import pyarrow as pa -import pyarrow.compute as pc import pyarrow.parquet as pq from lancedb import DBConnection +from lancedb.common import DATA # type: ignore from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore from numpy import ndarray -from pyarrow import Array, ChunkedArray, ArrowInvalid +from pyarrow import Array, ChunkedArray, ArrowInvalid, RecordBatchReader from dlt.common import json, pendulum, logger from dlt.common.destination import DestinationCapabilitiesContext @@ -45,7 +45,6 @@ HasFollowupJobs, WithStagingDataset, ) -from dlt.common.exceptions import SystemConfigurationException from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.schema.typing import ( @@ -57,7 +56,6 @@ ) from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages import FileStorage, LoadJobInfo, ParsedLoadJobFileName -from dlt.common.typing import DictStrAny from dlt.destinations.impl.lancedb.configuration import ( LanceDBClientConfiguration, ) @@ -66,7 +64,6 @@ ) from dlt.destinations.impl.lancedb.lancedb_adapter import ( VECTORIZE_HINT, - DOCUMENT_ID_HINT, ) from dlt.destinations.impl.lancedb.schema import ( make_arrow_field_schema, @@ -75,6 +72,7 @@ NULL_SCHEMA, TArrowField, arrow_datatype_to_fusion_datatype, + TTableLineage, ) from dlt.destinations.impl.lancedb.utils import ( get_unique_identifiers_from_table_schema, @@ -160,14 +158,15 @@ def from_db_type( return super().from_db_type(cast(str, db_type), precision, scale) -def write_to_db( - records: Union[pa.Table, List[DictStrAny]], +def write_records( + records: DATA, /, *, db_client: DBConnection, table_name: str, write_disposition: Optional[TWriteDisposition] = "append", id_field_name: Optional[str] = None, + remove_orphans: Optional[bool] = False, ) -> None: """Inserts records into a LanceDB table with automatic embedding computation. @@ -177,6 +176,7 @@ def write_to_db( table_name: The name of the table to insert into. id_field_name: The name of the ID field for update/merge operations. write_disposition: The write disposition - one of 'skip', 'append', 'replace', 'merge'. + remove_orphans (bool): Whether to remove orphans after insertion or not (only merge disposition). 
Raises: ValueError: If the write disposition is unsupported, or `id_field_name` is not @@ -199,9 +199,12 @@ def write_to_db( elif write_disposition == "merge": if not id_field_name: raise ValueError("To perform a merge update, 'id_field_name' must be specified.") - tbl.merge_insert( - id_field_name - ).when_matched_update_all().when_not_matched_insert_all().execute(records) + if remove_orphans: + tbl.merge_insert(id_field_name).when_not_matched_by_source_delete().execute(records) + else: + tbl.merge_insert( + id_field_name + ).when_matched_update_all().when_not_matched_insert_all().execute(records) else: raise DestinationTerminalException( f"Unsupported write disposition {write_disposition} for LanceDB Destination - batch" @@ -534,7 +537,7 @@ def update_schema_in_storage(self) -> None: "write_disposition" ) - write_to_db( + write_records( records, db_client=self.db_client, table_name=fq_version_table_name, @@ -683,7 +686,7 @@ def complete_load(self, load_id: str) -> None: write_disposition = self.schema.get_table(self.schema.loads_table_name).get( "write_disposition" ) - write_to_db( + write_records( records, db_client=self.db_client, table_name=fq_loads_table_name, @@ -771,7 +774,7 @@ def run(self) -> None: id_field_name=id_field_name, ) - write_to_db( + write_records( arrow_table, db_client=db_client, table_name=fq_table_name, @@ -793,89 +796,47 @@ def __init__( def run(self) -> None: db_client: DBConnection = self._job_client.db_client - id_field_name: str = self._job_client.config.id_field_name - - # We don't all insert jobs for each table using this method. - table_lineage: List[TTableSchema] = [] - for file_path_ in self.references: - table = self._schema.get_table(ParsedLoadJobFileName.parse(file_path_).table_name) - if table["name"] not in [table_["name"] for table_ in table_lineage]: - table_lineage.append(table) - - for table in table_lineage: - fq_table_name: str = self._job_client.make_qualified_table_name(table["name"]) - try: - fq_parent_table_name: str = self._job_client.make_qualified_table_name( - table["parent"] + table_lineage: List[Tuple[TTableSchema, str, str]] = [ + ( + self._schema.get_table(ParsedLoadJobFileName.parse(file_path_).table_name), + ParsedLoadJobFileName.parse(file_path_).table_name, + file_path_, + ) + for file_path_ in self.references + ] + source_table_id_field_name = "_dlt_id" + + for table, table_name, table_path in table_lineage: + target_is_root_table = "parent" not in table + fq_table_name = self._job_client.make_qualified_table_name(table_name) + target_table_schema = pq.read_schema(table_path) + + if target_is_root_table: + target_table_id_field_name = "_dlt_id" + arrow_ds = pa.dataset.dataset(table_path) + else: + # TODO: change schema of source table id to math target table id. 
+ target_table_id_field_name = "_dlt_parent_id" + parent_table_path = self.get_parent_path(table_lineage, table.get("parent")) + arrow_ds = pa.dataset.dataset(parent_table_path) + + arrow_rbr: RecordBatchReader + with arrow_ds.scanner( + columns=[source_table_id_field_name], + batch_size=BATCH_PROCESS_CHUNK_SIZE, + ).to_reader() as arrow_rbr: + records: pa.Table = arrow_rbr.read_all() + records_with_conforming_schema = records.join(target_table_schema.empty_table(), keys=source_table_id_field_name) + + write_records( + records_with_conforming_schema, + db_client=db_client, + id_field_name=target_table_id_field_name, + table_name=fq_table_name, + write_disposition="merge", + remove_orphans=True, ) - except KeyError: - fq_parent_table_name = None # The table is a root table. - - try: - child_table = db_client.open_table(fq_table_name) - child_table.checkout_latest() - if fq_parent_table_name: - parent_table = db_client.open_table(fq_parent_table_name) - parent_table.checkout_latest() - except FileNotFoundError as e: - raise DestinationTransientException( - "Couldn't open lancedb database. Orphan removal WILL BE RETRIED" - ) from e - - try: - if fq_parent_table_name: - # Chunks and embeddings in child table. - parent_ids = set( - pc.unique( - parent_table.to_lance().to_table(columns=["_dlt_id"])["_dlt_id"] - ).to_pylist() - ) - child_ids = set( - pc.unique( - child_table.to_lance().to_table(columns=["_dlt_parent_id"])[ - "_dlt_parent_id" - ] - ).to_pylist() - ) - - if orphaned_ids := child_ids - parent_ids: - if len(orphaned_ids) > 1: - child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids) == 1: - child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'") - else: - # Chunks and embeddings in the root table. - document_id_field = get_columns_names_with_prop(table, DOCUMENT_ID_HINT) - if document_id_field and get_columns_names_with_prop(table, "primary_key"): - raise SystemConfigurationException( - "You CANNOT specify a primary key AND a document ID hint for the same" - " resource when using merge disposition." - ) - - # If document ID is defined, we use this as the sole grouping key to identify stale chunks, - # else fallback to the compound `id_field_name`. - grouping_key = document_id_field or id_field_name - grouping_key = ( - grouping_key if isinstance(grouping_key, list) else [grouping_key] - ) - child_table_arrow: pa.Table = child_table.to_lance().to_table( - columns=[*grouping_key, "_dlt_load_id", "_dlt_id"] - ) - - grouped = child_table_arrow.group_by(grouping_key).aggregate( - [("_dlt_load_id", "max")] - ) - joined = child_table_arrow.join(grouped, keys=grouping_key) - orphaned_mask = pc.not_equal(joined["_dlt_load_id"], joined["_dlt_load_id_max"]) - orphaned_ids = joined.filter(orphaned_mask).column("_dlt_id").to_pylist() - - if len(orphaned_ids) > 1: - child_table.delete(f"_dlt_id IN {tuple(orphaned_ids)}") - elif len(orphaned_ids) == 1: - child_table.delete(f"_dlt_id = '{orphaned_ids.pop()}'") - - except ArrowInvalid as e: - raise DestinationTerminalException( - "Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED." 
- ) from e + @staticmethod + def get_parent_path(table_lineage: TTableLineage, table: str) -> Optional[str]: + return next(entry[1] for entry in table_lineage if entry[1] == table) diff --git a/dlt/destinations/impl/lancedb/schema.py b/dlt/destinations/impl/lancedb/schema.py index db624aeb12..ff6f76c07a 100644 --- a/dlt/destinations/impl/lancedb/schema.py +++ b/dlt/destinations/impl/lancedb/schema.py @@ -3,7 +3,7 @@ from typing import ( List, cast, - Optional, + Optional, Tuple, ) import pyarrow as pa @@ -11,7 +11,7 @@ from typing_extensions import TypeAlias from dlt.common.json import json -from dlt.common.schema import Schema, TColumnSchema +from dlt.common.schema import Schema, TColumnSchema, TTableSchema from dlt.common.typing import DictStrAny from dlt.destinations.type_mapping import TypeMapper @@ -21,6 +21,7 @@ TArrowField: TypeAlias = pa.Field NULL_SCHEMA: TArrowSchema = pa.schema([]) """Empty pyarrow Schema with no fields.""" +TTableLineage: TypeAlias = List[Tuple[TTableSchema, str, str]] def arrow_schema_to_dict(schema: TArrowSchema) -> DictStrAny: From c7098fdf2c2f54f4c17f9841ae86859a1aaef15a Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 20 Aug 2024 21:23:48 +0200 Subject: [PATCH 60/68] Keep with batch reader Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 82cbdcba26..ea1a9796ac 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -825,11 +825,9 @@ def run(self) -> None: columns=[source_table_id_field_name], batch_size=BATCH_PROCESS_CHUNK_SIZE, ).to_reader() as arrow_rbr: - records: pa.Table = arrow_rbr.read_all() - records_with_conforming_schema = records.join(target_table_schema.empty_table(), keys=source_table_id_field_name) write_records( - records_with_conforming_schema, + arrow_rbr, db_client=db_client, id_field_name=target_table_id_field_name, table_name=fq_table_name, From d8ddcae9e20c875486a7ea69b1d1eef0335cd4c3 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 22 Aug 2024 13:53:43 +0200 Subject: [PATCH 61/68] Remove Lancedb client's staging implementation Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index ea1a9796ac..bb512bcf2e 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -216,7 +216,7 @@ def write_records( ) from e -class LanceDBClient(JobClientBase, WithStateSync, WithStagingDataset): +class LanceDBClient(JobClientBase, WithStateSync): """LanceDB destination handler.""" model_func: TextEmbeddingFunction @@ -725,19 +725,6 @@ def create_table_chain_completed_followup_jobs( def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() - @contextlib.contextmanager - def with_staging_dataset(self) -> Iterator["LanceDBClient"]: - current_dataset_name = self.dataset_name - try: - self.dataset_name = self.schema.naming.normalize_table_identifier( - f"{current_dataset_name}_staging" - ) - yield self - finally: - self.dataset_name = current_dataset_name - - def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: - return False class 
LanceDBLoadJob(RunnableLoadJob, HasFollowupJobs): From 17137a6033826c6189da41db0fc8588ab74da48c Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 22 Aug 2024 15:13:23 +0200 Subject: [PATCH 62/68] Insert in memory arrow table. This will be optimized Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 69 +++++++++++-------- 1 file changed, 41 insertions(+), 28 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index bb512bcf2e..3ee9e513b9 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -1,4 +1,3 @@ -import contextlib from types import TracebackType from typing import ( List, @@ -13,7 +12,6 @@ Sequence, TYPE_CHECKING, Set, - Iterator, ) import lancedb # type: ignore @@ -25,7 +23,7 @@ from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore from numpy import ndarray -from pyarrow import Array, ChunkedArray, ArrowInvalid, RecordBatchReader +from pyarrow import Array, ChunkedArray, ArrowInvalid from dlt.common import json, pendulum, logger from dlt.common.destination import DestinationCapabilitiesContext @@ -43,7 +41,6 @@ FollowupJob, LoadJob, HasFollowupJobs, - WithStagingDataset, ) from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables @@ -200,6 +197,7 @@ def write_records( if not id_field_name: raise ValueError("To perform a merge update, 'id_field_name' must be specified.") if remove_orphans: + # tbl.to_lance().merge_insert(id_field_name).when_not_matched_by_source_delete().execute(records) tbl.merge_insert(id_field_name).when_not_matched_by_source_delete().execute(records) else: tbl.merge_insert( @@ -726,7 +724,6 @@ def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() - class LanceDBLoadJob(RunnableLoadJob, HasFollowupJobs): arrow_schema: TArrowSchema @@ -791,36 +788,52 @@ def run(self) -> None: ) for file_path_ in self.references ] - source_table_id_field_name = "_dlt_id" - for table, table_name, table_path in table_lineage: - target_is_root_table = "parent" not in table - fq_table_name = self._job_client.make_qualified_table_name(table_name) - target_table_schema = pq.read_schema(table_path) + for target_table, target_table_name, target_table_path in table_lineage: + target_is_root_table = "parent" not in target_table + fq_table_name = self._job_client.make_qualified_table_name(target_table_name) if target_is_root_table: target_table_id_field_name = "_dlt_id" - arrow_ds = pa.dataset.dataset(table_path) + # arrow_ds = pa.dataset.dataset(table_path) + with FileStorage.open_zipsafe_ro(target_table_path, mode="rb") as f: + payload_arrow_table: pa.Table = pq.read_table(f) + + # Append ID Field, which is only defined in the LanceDB table. + payload_arrow_table = payload_arrow_table.append_column( + pa.field(self._job_client.id_field_name, pa.string()), + pa.array([""] * payload_arrow_table.num_rows, type=pa.string()), + ) else: # TODO: change schema of source table id to math target table id. 
target_table_id_field_name = "_dlt_parent_id" - parent_table_path = self.get_parent_path(table_lineage, table.get("parent")) - arrow_ds = pa.dataset.dataset(parent_table_path) - - arrow_rbr: RecordBatchReader - with arrow_ds.scanner( - columns=[source_table_id_field_name], - batch_size=BATCH_PROCESS_CHUNK_SIZE, - ).to_reader() as arrow_rbr: - - write_records( - arrow_rbr, - db_client=db_client, - id_field_name=target_table_id_field_name, - table_name=fq_table_name, - write_disposition="merge", - remove_orphans=True, - ) + parent_table_path = self.get_parent_path(table_lineage, target_table.get("parent")) + # arrow_ds = pa.dataset.dataset(parent_table_path) + with FileStorage.open_zipsafe_ro(parent_table_path, mode="rb") as f: + payload_arrow_table: pa.Table = pq.read_table(f) + + # arrow_rbr: RecordBatchReader + # with arrow_ds.scanner( + # columns=[source_table_id_field_name], + # batch_size=BATCH_PROCESS_CHUNK_SIZE, + # ).to_reader() as arrow_rbr: + # with FileStorage.open_zipsafe_ro(target_table_path, mode="rb") as f: + # target_table_arrow_schema: pa.Schema = pq.read_schema(f) + + # payload_arrow_table_with_conforming_schema = payload_arrow_table.join( + # target_table_arrow_schema.empty_table(), + # keys=target_table_id_field_name, + # ) + + write_records( + # payload_arrow_table_with_conforming_schema, + payload_arrow_table, + db_client=db_client, + id_field_name=target_table_id_field_name, + table_name=fq_table_name, + write_disposition="merge", + remove_orphans=True, + ) @staticmethod def get_parent_path(table_lineage: TTableLineage, table: str) -> Optional[str]: From 53d896a1ac08d031a71b96b99b981db822805b1b Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 26 Aug 2024 14:08:15 +0200 Subject: [PATCH 63/68] Rename classes to the new job implementation classes Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 131 +++++++----------- dlt/destinations/impl/lancedb/utils.py | 17 +++ 2 files changed, 67 insertions(+), 81 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 3ee9e513b9..d0ff3c29dd 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -38,9 +38,9 @@ RunnableLoadJob, StorageSchemaInfo, StateInfo, - FollowupJob, LoadJob, HasFollowupJobs, + FollowupJobRequest, ) from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables @@ -75,8 +75,9 @@ get_unique_identifiers_from_table_schema, set_non_standard_providers_environment_variables, generate_arrow_uuid_column, + get_default_arrow_value, ) -from dlt.destinations.job_impl import ReferenceFollowupJob +from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -197,7 +198,6 @@ def write_records( if not id_field_name: raise ValueError("To perform a merge update, 'id_field_name' must be specified.") if remove_orphans: - # tbl.to_lance().merge_insert(id_field_name).when_not_matched_by_source_delete().execute(records) tbl.merge_insert(id_field_name).when_not_matched_by_source_delete().execute(records) else: tbl.merge_insert( @@ -694,30 +694,17 @@ def complete_load(self, load_id: str) -> None: def create_load_job( self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: - if ReferenceFollowupJob.is_reference_job(file_path): + if ReferenceFollowupJobRequest.is_reference_job(file_path): return 
LanceDBRemoveOrphansJob(file_path) else: return LanceDBLoadJob(file_path, table) - def create_table_chain_completed_followup_jobs( - self, - table_chain: Sequence[TTableSchema], - completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[FollowupJob]: - jobs = super().create_table_chain_completed_followup_jobs( - table_chain, completed_table_chain_jobs - ) - if table_chain[0].get("write_disposition") == "merge": - all_job_paths_ordered = [ - job.file_path - for table in table_chain - for job in completed_table_chain_jobs - if job.job_file_info.table_name == table.get("name") - ] - root_table_file_name = FileStorage.get_file_name_from_file_path( - all_job_paths_ordered[0] - ) - jobs.append(ReferenceFollowupJob(root_table_file_name, all_job_paths_ordered)) + def create_table_chain_completed_followup_jobs(self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJobRequest]: + jobs = super().create_table_chain_completed_followup_jobs(table_chain, completed_table_chain_jobs) + if table_chain[0].get("write_disposition")=="merge": + all_job_paths_ordered = [job.file_path for table in table_chain for job in completed_table_chain_jobs if job.job_file_info.table_name==table.get("name")] + root_table_file_name = FileStorage.get_file_name_from_file_path(all_job_paths_ordered[0]) + jobs.append(ReferenceFollowupJobRequest(root_table_file_name, all_job_paths_ordered)) return jobs def table_exists(self, table_name: str) -> bool: @@ -727,11 +714,7 @@ def table_exists(self, table_name: str) -> bool: class LanceDBLoadJob(RunnableLoadJob, HasFollowupJobs): arrow_schema: TArrowSchema - def __init__( - self, - file_path: str, - table_schema: TTableSchema, - ) -> None: + def __init__(self, file_path: str, table_schema: TTableSchema, ) -> None: super().__init__(file_path) self._job_client: "LanceDBClient" = None self._table_schema: TTableSchema = table_schema @@ -740,43 +723,25 @@ def run(self) -> None: db_client: DBConnection = self._job_client.db_client fq_table_name: str = self._job_client.make_qualified_table_name(self._table_schema["name"]) id_field_name: str = self._job_client.config.id_field_name - unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema( - self._load_table - ) - write_disposition: TWriteDisposition = cast( - TWriteDisposition, self._load_table.get("write_disposition", "append") - ) + unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema(self._load_table) + write_disposition: TWriteDisposition = cast(TWriteDisposition, self._load_table.get("write_disposition", "append")) with FileStorage.open_zipsafe_ro(self._file_path, mode="rb") as f: arrow_table: pa.Table = pq.read_table(f) if self._load_table["name"] not in self._schema.dlt_table_names(): - arrow_table = generate_arrow_uuid_column( - arrow_table, - unique_identifiers=unique_identifiers, - table_name=fq_table_name, - id_field_name=id_field_name, - ) + arrow_table = generate_arrow_uuid_column(arrow_table, unique_identifiers=unique_identifiers, table_name=fq_table_name, id_field_name=id_field_name, ) - write_records( - arrow_table, - db_client=db_client, - table_name=fq_table_name, - write_disposition=write_disposition, - id_field_name=id_field_name, - ) + write_records(arrow_table, db_client=db_client, table_name=fq_table_name, write_disposition=write_disposition, id_field_name=id_field_name, ) class LanceDBRemoveOrphansJob(RunnableLoadJob): orphaned_ids: Set[str] - def __init__( - self, - file_path: 
str, - ) -> None: + def __init__(self, file_path: str, ) -> None: super().__init__(file_path) self._job_client: "LanceDBClient" = None - self.references = ReferenceFollowupJob.resolve_references(file_path) + self.references = ReferenceFollowupJobRequest.resolve_references(file_path) def run(self) -> None: db_client: DBConnection = self._job_client.db_client @@ -795,38 +760,42 @@ def run(self) -> None: if target_is_root_table: target_table_id_field_name = "_dlt_id" - # arrow_ds = pa.dataset.dataset(table_path) - with FileStorage.open_zipsafe_ro(target_table_path, mode="rb") as f: - payload_arrow_table: pa.Table = pq.read_table(f) - - # Append ID Field, which is only defined in the LanceDB table. - payload_arrow_table = payload_arrow_table.append_column( - pa.field(self._job_client.id_field_name, pa.string()), - pa.array([""] * payload_arrow_table.num_rows, type=pa.string()), - ) + file_path = target_table_path else: - # TODO: change schema of source table id to math target table id. target_table_id_field_name = "_dlt_parent_id" - parent_table_path = self.get_parent_path(table_lineage, target_table.get("parent")) - # arrow_ds = pa.dataset.dataset(parent_table_path) - with FileStorage.open_zipsafe_ro(parent_table_path, mode="rb") as f: - payload_arrow_table: pa.Table = pq.read_table(f) - - # arrow_rbr: RecordBatchReader - # with arrow_ds.scanner( - # columns=[source_table_id_field_name], - # batch_size=BATCH_PROCESS_CHUNK_SIZE, - # ).to_reader() as arrow_rbr: - # with FileStorage.open_zipsafe_ro(target_table_path, mode="rb") as f: - # target_table_arrow_schema: pa.Schema = pq.read_schema(f) - - # payload_arrow_table_with_conforming_schema = payload_arrow_table.join( - # target_table_arrow_schema.empty_table(), - # keys=target_table_id_field_name, - # ) + file_path = self.get_parent_path(table_lineage, target_table.get("parent")) + + with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: + payload_arrow_table: pa.Table = pq.read_table(f) + + # Get target table schema + with FileStorage.open_zipsafe_ro(target_table_path, mode="rb") as f: + target_table_schema: pa.Schema = pq.read_schema(f) + + # LanceDB requires the payload to have all fields populated, even if we don't intend to use them in our merge operation. + # Unfortunately, we can't just create NULL fields; else LanceDB always truncates the target using `when_not_matched_by_source_delete`. + schema_difference = pa.schema( + set(target_table_schema) - set(payload_arrow_table.schema) + ) + for field in schema_difference: + try: + default_value = get_default_arrow_value(field.type) + default_array = pa.array( + [default_value] * payload_arrow_table.num_rows, type=field.type + ) + payload_arrow_table = payload_arrow_table.append_column(field, default_array) + except ValueError as e: + logger.warn(f"{e}. Using null values for field '{field.name}'.") + null_array = pa.array([None] * payload_arrow_table.num_rows, type=field.type) + payload_arrow_table = payload_arrow_table.append_column(field, null_array) + + # TODO: Remove special field, we don't need it. 
+ payload_arrow_table = payload_arrow_table.append_column( + pa.field(self._job_client.id_field_name, pa.string()), + pa.array([""] * payload_arrow_table.num_rows, type=pa.string()), + ) write_records( - # payload_arrow_table_with_conforming_schema, payload_arrow_table, db_client=db_client, id_field_name=target_table_id_field_name, @@ -837,4 +806,4 @@ def run(self) -> None: @staticmethod def get_parent_path(table_lineage: TTableLineage, table: str) -> Optional[str]: - return next(entry[1] for entry in table_lineage if entry[1] == table) + return next(entry[2] for entry in table_lineage if entry[1] == table) diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 37303686df..fe2bfa48bd 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -1,5 +1,6 @@ import os import uuid +from datetime import date, datetime from typing import Sequence, Union, Dict, List import pyarrow as pa @@ -74,3 +75,19 @@ def set_non_standard_providers_environment_variables( ) -> None: if embedding_model_provider in PROVIDER_ENVIRONMENT_VARIABLES_MAP: os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or "" + +def get_default_arrow_value(field_type): + if pa.types.is_integer(field_type): + return 0 + elif pa.types.is_floating(field_type): + return 0.0 + elif pa.types.is_string(field_type): + return "" + elif pa.types.is_boolean(field_type): + return False + elif pa.types.is_date(field_type): + return date.today() + elif pa.types.is_timestamp(field_type): + return datetime.now() + else: + raise ValueError(f"Unsupported data type: {field_type}") \ No newline at end of file From 26ba0f57f6db6d69864e0aec21103cac5d07bbb3 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 26 Aug 2024 15:03:49 +0200 Subject: [PATCH 64/68] Use namedtuple for table chain to improve readability Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 86 +++++++++++++------ dlt/destinations/impl/lancedb/schema.py | 6 +- dlt/destinations/impl/lancedb/utils.py | 8 +- 3 files changed, 71 insertions(+), 29 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index d0ff3c29dd..a099974d56 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -70,6 +70,7 @@ TArrowField, arrow_datatype_to_fusion_datatype, TTableLineage, + TableJob, ) from dlt.destinations.impl.lancedb.utils import ( get_unique_identifiers_from_table_schema, @@ -699,11 +700,24 @@ def create_load_job( else: return LanceDBLoadJob(file_path, table) - def create_table_chain_completed_followup_jobs(self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJobRequest]: - jobs = super().create_table_chain_completed_followup_jobs(table_chain, completed_table_chain_jobs) - if table_chain[0].get("write_disposition")=="merge": - all_job_paths_ordered = [job.file_path for table in table_chain for job in completed_table_chain_jobs if job.job_file_info.table_name==table.get("name")] - root_table_file_name = FileStorage.get_file_name_from_file_path(all_job_paths_ordered[0]) + def create_table_chain_completed_followup_jobs( + self, + table_chain: Sequence[TTableSchema], + completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, + ) -> List[FollowupJobRequest]: + jobs = super().create_table_chain_completed_followup_jobs( + table_chain, 
completed_table_chain_jobs + ) + if table_chain[0].get("write_disposition") == "merge": + all_job_paths_ordered = [ + job.file_path + for table in table_chain + for job in completed_table_chain_jobs + if job.job_file_info.table_name == table.get("name") + ] + root_table_file_name = FileStorage.get_file_name_from_file_path( + all_job_paths_ordered[0] + ) jobs.append(ReferenceFollowupJobRequest(root_table_file_name, all_job_paths_ordered)) return jobs @@ -714,7 +728,11 @@ def table_exists(self, table_name: str) -> bool: class LanceDBLoadJob(RunnableLoadJob, HasFollowupJobs): arrow_schema: TArrowSchema - def __init__(self, file_path: str, table_schema: TTableSchema, ) -> None: + def __init__( + self, + file_path: str, + table_schema: TTableSchema, + ) -> None: super().__init__(file_path) self._job_client: "LanceDBClient" = None self._table_schema: TTableSchema = table_schema @@ -723,53 +741,73 @@ def run(self) -> None: db_client: DBConnection = self._job_client.db_client fq_table_name: str = self._job_client.make_qualified_table_name(self._table_schema["name"]) id_field_name: str = self._job_client.config.id_field_name - unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema(self._load_table) - write_disposition: TWriteDisposition = cast(TWriteDisposition, self._load_table.get("write_disposition", "append")) + unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema( + self._load_table + ) + write_disposition: TWriteDisposition = cast( + TWriteDisposition, self._load_table.get("write_disposition", "append") + ) with FileStorage.open_zipsafe_ro(self._file_path, mode="rb") as f: arrow_table: pa.Table = pq.read_table(f) if self._load_table["name"] not in self._schema.dlt_table_names(): - arrow_table = generate_arrow_uuid_column(arrow_table, unique_identifiers=unique_identifiers, table_name=fq_table_name, id_field_name=id_field_name, ) + arrow_table = generate_arrow_uuid_column( + arrow_table, + unique_identifiers=unique_identifiers, + table_name=fq_table_name, + id_field_name=id_field_name, + ) - write_records(arrow_table, db_client=db_client, table_name=fq_table_name, write_disposition=write_disposition, id_field_name=id_field_name, ) + write_records( + arrow_table, + db_client=db_client, + table_name=fq_table_name, + write_disposition=write_disposition, + id_field_name=id_field_name, + ) class LanceDBRemoveOrphansJob(RunnableLoadJob): orphaned_ids: Set[str] - def __init__(self, file_path: str, ) -> None: + def __init__( + self, + file_path: str, + ) -> None: super().__init__(file_path) self._job_client: "LanceDBClient" = None self.references = ReferenceFollowupJobRequest.resolve_references(file_path) def run(self) -> None: db_client: DBConnection = self._job_client.db_client - table_lineage: List[Tuple[TTableSchema, str, str]] = [ - ( - self._schema.get_table(ParsedLoadJobFileName.parse(file_path_).table_name), - ParsedLoadJobFileName.parse(file_path_).table_name, - file_path_, + table_lineage: TTableLineage = [ + TableJob( + table_schema=self._schema.get_table( + ParsedLoadJobFileName.parse(file_path_).table_name + ), + table_name=ParsedLoadJobFileName.parse(file_path_).table_name, + file_path=file_path_, ) for file_path_ in self.references ] - for target_table, target_table_name, target_table_path in table_lineage: - target_is_root_table = "parent" not in target_table - fq_table_name = self._job_client.make_qualified_table_name(target_table_name) + for job in table_lineage: + target_is_root_table = "parent" not in job.table_schema + fq_table_name = 
self._job_client.make_qualified_table_name(job.table_name) if target_is_root_table: target_table_id_field_name = "_dlt_id" - file_path = target_table_path + file_path = job.file_path else: target_table_id_field_name = "_dlt_parent_id" - file_path = self.get_parent_path(table_lineage, target_table.get("parent")) + file_path = self.get_parent_path(table_lineage, job.table_schema.get("parent")) with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: payload_arrow_table: pa.Table = pq.read_table(f) # Get target table schema - with FileStorage.open_zipsafe_ro(target_table_path, mode="rb") as f: + with FileStorage.open_zipsafe_ro(job.file_path, mode="rb") as f: target_table_schema: pa.Schema = pq.read_schema(f) # LanceDB requires the payload to have all fields populated, even if we don't intend to use them in our merge operation. @@ -805,5 +843,5 @@ def run(self) -> None: ) @staticmethod - def get_parent_path(table_lineage: TTableLineage, table: str) -> Optional[str]: - return next(entry[2] for entry in table_lineage if entry[1] == table) + def get_parent_path(table_lineage: TTableLineage, table: str) -> Any: + return next(entry.file_path for entry in table_lineage if entry.table_name == table) diff --git a/dlt/destinations/impl/lancedb/schema.py b/dlt/destinations/impl/lancedb/schema.py index ff6f76c07a..6cc562038e 100644 --- a/dlt/destinations/impl/lancedb/schema.py +++ b/dlt/destinations/impl/lancedb/schema.py @@ -1,5 +1,5 @@ """Utilities for creating arrow schemas from table schemas.""" - +from collections import namedtuple from typing import ( List, cast, @@ -21,7 +21,8 @@ TArrowField: TypeAlias = pa.Field NULL_SCHEMA: TArrowSchema = pa.schema([]) """Empty pyarrow Schema with no fields.""" -TTableLineage: TypeAlias = List[Tuple[TTableSchema, str, str]] +TableJob = namedtuple('TableJob', ['table_schema', 'table_name', 'file_path']) +TTableLineage: TypeAlias = List[TableJob] def arrow_schema_to_dict(schema: TArrowSchema) -> DictStrAny: @@ -102,3 +103,4 @@ def arrow_datatype_to_fusion_datatype(arrow_type: TArrowSchema) -> str: return "TIMESTAMP" return type_map.get(arrow_type, "UNKNOWN") + diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index fe2bfa48bd..e1847120bd 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -9,6 +9,7 @@ from dlt.common.schema import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider +from dlt.destinations.impl.lancedb.schema import TArrowDataType PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = { @@ -62,7 +63,7 @@ def get_unique_identifiers_from_table_schema(table_schema: TTableSchema) -> List """ primary_keys = get_columns_names_with_prop(table_schema, "primary_key") merge_keys = [] - if table_schema.get("write_disposition") == "merge": + if table_schema.get("write_disposition")=="merge": merge_keys = get_columns_names_with_prop(table_schema, "merge_key") if join_keys := list(set(primary_keys + merge_keys)): return join_keys @@ -76,7 +77,8 @@ def set_non_standard_providers_environment_variables( if embedding_model_provider in PROVIDER_ENVIRONMENT_VARIABLES_MAP: os.environ[PROVIDER_ENVIRONMENT_VARIABLES_MAP[embedding_model_provider]] = api_key or "" -def get_default_arrow_value(field_type): + +def get_default_arrow_value(field_type: TArrowDataType) -> object: if pa.types.is_integer(field_type): return 0 elif pa.types.is_floating(field_type): @@ -90,4 
+92,4 @@ def get_default_arrow_value(field_type): elif pa.types.is_timestamp(field_type): return datetime.now() else: - raise ValueError(f"Unsupported data type: {field_type}") \ No newline at end of file + raise ValueError(f"Unsupported data type: {field_type}") From 06e04d9fe6fbb76095129d93f5872c3c42566ca8 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 26 Aug 2024 17:43:03 +0200 Subject: [PATCH 65/68] Remove orphans by loading all ancestor IDs simultaneously Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 35 +++++++++++++------ dlt/destinations/impl/lancedb/utils.py | 26 +++++++++++++- 2 files changed, 50 insertions(+), 11 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index a099974d56..4747856c73 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -77,6 +77,7 @@ set_non_standard_providers_environment_variables, generate_arrow_uuid_column, get_default_arrow_value, + create_unique_table_lineage, ) from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.type_mapping import TypeMapper @@ -711,9 +712,8 @@ def create_table_chain_completed_followup_jobs( if table_chain[0].get("write_disposition") == "merge": all_job_paths_ordered = [ job.file_path - for table in table_chain + for _ in table_chain for job in completed_table_chain_jobs - if job.job_file_info.table_name == table.get("name") ] root_table_file_name = FileStorage.get_file_name_from_file_path( all_job_paths_ordered[0] @@ -791,22 +791,35 @@ def run(self) -> None: ) for file_path_ in self.references ] + table_lineage_unique: TTableLineage = create_unique_table_lineage(table_lineage) - for job in table_lineage: + for job in table_lineage_unique: target_is_root_table = "parent" not in job.table_schema fq_table_name = self._job_client.make_qualified_table_name(job.table_name) if target_is_root_table: target_table_id_field_name = "_dlt_id" - file_path = job.file_path + ancestors_file_paths = self.get_parent_paths(table_lineage, job.table_name) else: target_table_id_field_name = "_dlt_parent_id" - file_path = self.get_parent_path(table_lineage, job.table_schema.get("parent")) + ancestors_file_paths = self.get_parent_paths( + table_lineage, job.table_schema.get("parent") + ) - with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: - payload_arrow_table: pa.Table = pq.read_table(f) + # `when_not_matched_by_source_delete` removes absent source IDs. + # Loading ancestors individually risks unintended ID deletion, necessitating simultaneous loading of all ancestor IDs. + payload_arrow_table = None + for file_path_ in ancestors_file_paths: + with FileStorage.open_zipsafe_ro(file_path_, mode="rb") as f: + ancestor_arrow_table: pa.Table = pq.read_table(f) + if payload_arrow_table is None: + payload_arrow_table = ancestor_arrow_table + else: + payload_arrow_table = pa.concat_tables( + [payload_arrow_table, ancestor_arrow_table] + ) - # Get target table schema + # Get target table schema. 
with FileStorage.open_zipsafe_ro(job.file_path, mode="rb") as f: target_table_schema: pa.Schema = pq.read_schema(f) @@ -843,5 +856,7 @@ def run(self) -> None: ) @staticmethod - def get_parent_path(table_lineage: TTableLineage, table: str) -> Any: - return next(entry.file_path for entry in table_lineage if entry.table_name == table) + def get_parent_paths(table_lineage: TTableLineage, table: str) -> List[str]: + """Return all load files for a given table in the same order in which they were + loaded, thereby maintaining the load history of the table.""" + return [entry.file_path for entry in table_lineage if entry.table_name == table] diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index e1847120bd..624f3f4a4f 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -9,7 +9,7 @@ from dlt.common.schema import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider -from dlt.destinations.impl.lancedb.schema import TArrowDataType +from dlt.destinations.impl.lancedb.schema import TArrowDataType, TTableLineage PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = { @@ -93,3 +93,27 @@ def get_default_arrow_value(field_type: TArrowDataType) -> object: return datetime.now() else: raise ValueError(f"Unsupported data type: {field_type}") + + +def create_unique_table_lineage(table_lineage: TTableLineage) -> TTableLineage: + """Create a unique table lineage, keeping the last job for each table. + + Args: + table_lineage: The full table lineage. + + Returns: + A new list of TableJob objects with the duplicates removed, keeping the + last occurrence of each unique table name while maintaining the + original order of appearance. + """ + seen_table_names = set() + unique_lineage = [] + + for job in reversed(table_lineage): + if job.table_name not in seen_table_names: + seen_table_names.add(job.table_name) + unique_lineage.append(job) + + return list(reversed(unique_lineage)) + + From 43eb5b4aa87dfe05d5c839b894f297cab638af14 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 26 Aug 2024 19:01:31 +0200 Subject: [PATCH 66/68] Fix doc_id adapter Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_adapter.py | 2 +- tests/load/lancedb/test_remove_orphaned_records.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_adapter.py b/dlt/destinations/impl/lancedb/lancedb_adapter.py index 2f6b44d131..0daba7a651 100644 --- a/dlt/destinations/impl/lancedb/lancedb_adapter.py +++ b/dlt/destinations/impl/lancedb/lancedb_adapter.py @@ -56,7 +56,7 @@ def lancedb_adapter( if document_id: if isinstance(document_id, str): - embed = [document_id] + document_id = [document_id] if not isinstance(document_id, list): raise ValueError( "'document_id' must be a list of column names or a single column name as a string." 
diff --git a/tests/load/lancedb/test_remove_orphaned_records.py b/tests/load/lancedb/test_remove_orphaned_records.py index 3118b06cc7..6346171574 100644 --- a/tests/load/lancedb/test_remove_orphaned_records.py +++ b/tests/load/lancedb/test_remove_orphaned_records.py @@ -198,11 +198,10 @@ def identity_resource( def test_lancedb_root_table_remove_orphaned_records_with_real_embeddings() -> None: - @dlt.resource( # type: ignore + @dlt.resource( write_disposition="merge", table_name="document", merge_key=["chunk"], - columns={"doc_id": {DOCUMENT_ID_HINT: True}}, ) def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]: for doc in docs: From 40a5e7372abe24f496cffa33bfc7a10863cd0faf Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Mon, 26 Aug 2024 19:13:44 +0200 Subject: [PATCH 67/68] Revert to previous Signed-off-by: Marcel Coetzee --- dlt/destinations/impl/lancedb/lancedb_client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 4747856c73..c2513c022b 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -712,8 +712,9 @@ def create_table_chain_completed_followup_jobs( if table_chain[0].get("write_disposition") == "merge": all_job_paths_ordered = [ job.file_path - for _ in table_chain + for table in table_chain for job in completed_table_chain_jobs + if job.job_file_info.table_name == table.get("name") ] root_table_file_name = FileStorage.get_file_name_from_file_path( all_job_paths_ordered[0] From 8cd6003f4e0bc8b821b83ed7515bf762f8280c14 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 27 Aug 2024 13:39:57 +0200 Subject: [PATCH 68/68] Revert "Remove orphans by loading all ancestor IDs simultaneously" This reverts commit 06e04d9fe6fbb76095129d93f5872c3c42566ca8. --- .../impl/lancedb/lancedb_client.py | 32 +++++-------------- dlt/destinations/impl/lancedb/utils.py | 26 +-------------- 2 files changed, 9 insertions(+), 49 deletions(-) diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index c2513c022b..a099974d56 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -77,7 +77,6 @@ set_non_standard_providers_environment_variables, generate_arrow_uuid_column, get_default_arrow_value, - create_unique_table_lineage, ) from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.type_mapping import TypeMapper @@ -792,35 +791,22 @@ def run(self) -> None: ) for file_path_ in self.references ] - table_lineage_unique: TTableLineage = create_unique_table_lineage(table_lineage) - for job in table_lineage_unique: + for job in table_lineage: target_is_root_table = "parent" not in job.table_schema fq_table_name = self._job_client.make_qualified_table_name(job.table_name) if target_is_root_table: target_table_id_field_name = "_dlt_id" - ancestors_file_paths = self.get_parent_paths(table_lineage, job.table_name) + file_path = job.file_path else: target_table_id_field_name = "_dlt_parent_id" - ancestors_file_paths = self.get_parent_paths( - table_lineage, job.table_schema.get("parent") - ) + file_path = self.get_parent_path(table_lineage, job.table_schema.get("parent")) - # `when_not_matched_by_source_delete` removes absent source IDs. 
- # Loading ancestors individually risks unintended ID deletion, necessitating simultaneous loading of all ancestor IDs. - payload_arrow_table = None - for file_path_ in ancestors_file_paths: - with FileStorage.open_zipsafe_ro(file_path_, mode="rb") as f: - ancestor_arrow_table: pa.Table = pq.read_table(f) - if payload_arrow_table is None: - payload_arrow_table = ancestor_arrow_table - else: - payload_arrow_table = pa.concat_tables( - [payload_arrow_table, ancestor_arrow_table] - ) + with FileStorage.open_zipsafe_ro(file_path, mode="rb") as f: + payload_arrow_table: pa.Table = pq.read_table(f) - # Get target table schema. + # Get target table schema with FileStorage.open_zipsafe_ro(job.file_path, mode="rb") as f: target_table_schema: pa.Schema = pq.read_schema(f) @@ -857,7 +843,5 @@ def run(self) -> None: ) @staticmethod - def get_parent_paths(table_lineage: TTableLineage, table: str) -> List[str]: - """Return all load files for a given table in the same order in which they were - loaded, thereby maintaining the load history of the table.""" - return [entry.file_path for entry in table_lineage if entry.table_name == table] + def get_parent_path(table_lineage: TTableLineage, table: str) -> Any: + return next(entry.file_path for entry in table_lineage if entry.table_name == table) diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py index 624f3f4a4f..e1847120bd 100644 --- a/dlt/destinations/impl/lancedb/utils.py +++ b/dlt/destinations/impl/lancedb/utils.py @@ -9,7 +9,7 @@ from dlt.common.schema import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop from dlt.destinations.impl.lancedb.configuration import TEmbeddingProvider -from dlt.destinations.impl.lancedb.schema import TArrowDataType, TTableLineage +from dlt.destinations.impl.lancedb.schema import TArrowDataType PROVIDER_ENVIRONMENT_VARIABLES_MAP: Dict[TEmbeddingProvider, str] = { @@ -93,27 +93,3 @@ def get_default_arrow_value(field_type: TArrowDataType) -> object: return datetime.now() else: raise ValueError(f"Unsupported data type: {field_type}") - - -def create_unique_table_lineage(table_lineage: TTableLineage) -> TTableLineage: - """Create a unique table lineage, keeping the last job for each table. - - Args: - table_lineage: The full table lineage. - - Returns: - A new list of TableJob objects with the duplicates removed, keeping the - last occurrence of each unique table name while maintaining the - original order of appearance. - """ - seen_table_names = set() - unique_lineage = [] - - for job in reversed(table_lineage): - if job.table_name not in seen_table_names: - seen_table_names.add(job.table_name) - unique_lineage.append(job) - - return list(reversed(unique_lineage)) - -