From dd973c5d56edb7f6478f719d7cfe8c0e45b8939e Mon Sep 17 00:00:00 2001 From: Marcel Coetzee <34739235+Pipboyguy@users.noreply.github.com> Date: Tue, 3 Sep 2024 10:20:50 +0200 Subject: [PATCH] Don't use Custom Embedding Functions (#1771) * - Change default vector column name to "vector" to conform with lancedb standard - Add search tests with tantivy as search engine Signed-off-by: Marcel Coetzee * Format and fix linting Signed-off-by: Marcel Coetzee * Add custom embedding function registration test Signed-off-by: Marcel Coetzee * Spawn process in test to make sure registry can be deserialized from arrow files Signed-off-by: Marcel Coetzee * Simplify null string handling Signed-off-by: Marcel Coetzee * Change NULL string replacement with random string, doc clarification Signed-off-by: Marcel Coetzee * Update default vector column name in docs Signed-off-by: Marcel Coetzee --------- Signed-off-by: Marcel Coetzee --- .../impl/lancedb/configuration.py | 2 +- .../impl/lancedb/lancedb_client.py | 36 ++--- dlt/destinations/impl/lancedb/models.py | 34 ----- .../dlt-ecosystem/destinations/lancedb.md | 12 +- poetry.lock | 44 +++++- pyproject.toml | 3 +- tests/load/lancedb/test_pipeline.py | 131 +++++++++++++++++- tests/load/lancedb/utils.py | 2 +- 8 files changed, 204 insertions(+), 60 deletions(-) delete mode 100644 dlt/destinations/impl/lancedb/models.py diff --git a/dlt/destinations/impl/lancedb/configuration.py b/dlt/destinations/impl/lancedb/configuration.py index ba3a8b49d9..329132f495 100644 --- a/dlt/destinations/impl/lancedb/configuration.py +++ b/dlt/destinations/impl/lancedb/configuration.py @@ -90,7 +90,7 @@ class LanceDBClientConfiguration(DestinationClientDwhConfiguration): but it is configurable in rare cases. Make sure it corresponds with the associated embedding model's dimensionality.""" - vector_field_name: str = "vector__" + vector_field_name: str = "vector" """Name of the special field to store the vector embeddings.""" id_field_name: str = "id__" """Name of the special field to manage deduplication.""" diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 02240b8f93..8d4b6303ef 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -1,7 +1,6 @@ import uuid from types import TracebackType from typing import ( - ClassVar, List, Any, cast, @@ -37,7 +36,6 @@ RunnableLoadJob, StorageSchemaInfo, StateInfo, - TLoadJobState, LoadJob, ) from dlt.common.pendulum import timedelta @@ -70,7 +68,6 @@ generate_uuid, set_non_standard_providers_environment_variables, ) -from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -81,6 +78,7 @@ TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} +EMPTY_STRING_PLACEHOLDER = "0uEoDNBpQUBwsxKbmxxB" class LanceDBTypeMapper(TypeMapper): @@ -239,20 +237,11 @@ def __init__( embedding_model_provider, self.config.credentials.embedding_model_provider_api_key, ) - # Use the monkey-patched implementation if openai was chosen. 
- if embedding_model_provider == "openai": - from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings - - self.model_func = PatchedOpenAIEmbeddings( - max_retries=self.config.options.max_retries, - api_key=self.config.credentials.api_key, - ) - else: - self.model_func = self.registry.get(embedding_model_provider).create( - name=self.config.embedding_model, - max_retries=self.config.options.max_retries, - api_key=self.config.credentials.api_key, - ) + self.model_func = self.registry.get(embedding_model_provider).create( + name=self.config.embedding_model, + max_retries=self.config.options.max_retries, + api_key=self.config.credentials.api_key, + ) self.vector_field_name = self.config.vector_field_name self.id_field_name = self.config.id_field_name @@ -737,6 +726,19 @@ def run(self) -> None: with FileStorage.open_zipsafe_ro(self._file_path) as f: records: List[DictStrAny] = [json.loads(line) for line in f] + # Replace empty strings with placeholder string if OpenAI is used. + # https://github.com/lancedb/lancedb/issues/1577#issuecomment-2318104218. + if (self._job_client.config.embedding_model_provider == "openai") and ( + source_columns := get_columns_names_with_prop(self._load_table, VECTORIZE_HINT) + ): + records = [ + { + k: EMPTY_STRING_PLACEHOLDER if k in source_columns and v in ("", None) else v + for k, v in record.items() + } + for record in records + ] + if self._load_table not in self._schema.dlt_tables(): for record in records: # Add reserved ID fields. diff --git a/dlt/destinations/impl/lancedb/models.py b/dlt/destinations/impl/lancedb/models.py deleted file mode 100644 index d90adb62bd..0000000000 --- a/dlt/destinations/impl/lancedb/models.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import Union, List - -import numpy as np -from lancedb.embeddings import OpenAIEmbeddings # type: ignore -from lancedb.embeddings.registry import register # type: ignore -from lancedb.embeddings.utils import TEXT # type: ignore - - -@register("openai_patched") -class PatchedOpenAIEmbeddings(OpenAIEmbeddings): - EMPTY_STRING_PLACEHOLDER: str = "___EMPTY___" - - def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]: # type: ignore[type-arg] - """ - Replace empty strings with a placeholder value. - """ - - sanitized_texts = super().sanitize_input(texts) - return [self.EMPTY_STRING_PLACEHOLDER if item == "" else item for item in sanitized_texts] - - def generate_embeddings( - self, - texts: Union[List[str], np.ndarray], # type: ignore[type-arg] - ) -> List[np.array]: # type: ignore[valid-type] - """ - Generate embeddings, treating the placeholder as an empty result. - """ - embeddings: List[np.array] = super().generate_embeddings(texts) # type: ignore[valid-type] - - for i, text in enumerate(texts): - if text == self.EMPTY_STRING_PLACEHOLDER: - embeddings[i] = np.zeros(self.ndims()) - - return embeddings diff --git a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md index 8b7f3854ee..0d726508e6 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/lancedb.md +++ b/docs/website/docs/dlt-ecosystem/destinations/lancedb.md @@ -201,7 +201,7 @@ This is the default disposition. It will append the data to the existing data in ## Additional Destination Options - `dataset_separator`: The character used to separate the dataset name from table names. Defaults to "___". -- `vector_field_name`: The name of the special field to store vector embeddings. Defaults to "vector__". 
+- `vector_field_name`: The name of the special field to store vector embeddings. Defaults to "vector".
 - `id_field_name`: The name of the special field used for deduplication and merging. Defaults to "id__".
 - `max_retries`: The maximum number of retries for embedding operations. Set to 0 to disable retries. Defaults to 3.
 
@@ -216,11 +216,21 @@ The LanceDB destination supports syncing of the `dlt` state.
 
 ## Current Limitations
 
+### In-Memory Tables
+
 Adding new fields to an existing LanceDB table requires loading the entire table data into memory as a PyArrow table.
 This is because PyArrow tables are immutable, so adding fields requires creating a new table with the updated schema.
 
 For huge tables, this may impact performance and memory usage since the full table must be loaded into memory to add the new fields.
 Keep these considerations in mind when working with large datasets and monitor memory usage if adding fields to sizable existing tables.
 
+### Null string handling for OpenAI embeddings
+
+The OpenAI embedding service doesn't accept empty strings. We deal with this by replacing empty strings with a placeholder that should be semantically dissimilar to the vast majority of queries.
+
+If your source column (the column that is embedded) contains empty values, it is important to consider the impact of this substitution. There is a _slight_ chance that semantic queries can hit these placeholder values.
+
+We reported this issue to LanceDB: https://github.com/lancedb/lancedb/issues/1577.
+
diff --git a/poetry.lock b/poetry.lock
index 230b354b97..1bfdb776a2 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
[[package]] name = "about-time" @@ -8647,6 +8647,44 @@ files = [ [package.extras] widechars = ["wcwidth"] +[[package]] +name = "tantivy" +version = "0.22.0" +description = "" +optional = true +python-versions = ">=3.8" +files = [ + {file = "tantivy-0.22.0-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:732ec74c4dd531253af4c14756b7650527f22c7fab244e83b42d76a0a1437219"}, + {file = "tantivy-0.22.0-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:bf1da07b7e1003af4260b1ef3c3db7cb05db1578606092a6ca7a3cff2a22858a"}, + {file = "tantivy-0.22.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:689ed52985e914c531eadd8dd2df1b29f0fa684687b6026206dbdc57cf9297b2"}, + {file = "tantivy-0.22.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5f2885c8e98d1efcc4836c3e9d327d6ba2bc6b5e2cd8ac9b0356af18f571070"}, + {file = "tantivy-0.22.0-cp310-none-win_amd64.whl", hash = "sha256:4543cc72f4fec30f50fed5cd503c13d0da7cffda47648c7b72c1759103309e41"}, + {file = "tantivy-0.22.0-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:ec693abf38f229bc1361b0d34029a8bb9f3ee5bb956a3e745e0c4a66ea815bec"}, + {file = "tantivy-0.22.0-cp311-cp311-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:e385839badc12b81e38bf0a4d865ee7c3a992fea9f5ce4117adae89369e7d1eb"}, + {file = "tantivy-0.22.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b6c097d94be1af106676c86c02b185f029484fdbd9a2b9f17cb980e840e7bdad"}, + {file = "tantivy-0.22.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c47a5cdec306ea8594cb6e7effd4b430932ebfd969f9e8f99e343adf56a79bc9"}, + {file = "tantivy-0.22.0-cp311-none-win_amd64.whl", hash = "sha256:ba0ca878ed025d79edd9c51cda80b0105be8facbaec180fea64a17b80c74e7db"}, + {file = "tantivy-0.22.0-cp312-cp312-macosx_10_7_x86_64.whl", hash = "sha256:925682f3acb65c85c2a5a5b131401b9f30c184ea68aa73a8cc7c2ea6115e8ae3"}, + {file = "tantivy-0.22.0-cp312-cp312-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:d75760e45a329313001354d6ca415ff12d9d812343792ae133da6bfbdc4b04a5"}, + {file = "tantivy-0.22.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd909d122b5af457d955552c304f8d5d046aee7024c703c62652ad72af89f3c7"}, + {file = "tantivy-0.22.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c99266ffb204721eb2bd5b3184aa87860a6cff51b4563f808f78fa22d85a8093"}, + {file = "tantivy-0.22.0-cp312-none-win_amd64.whl", hash = "sha256:9ed6b813a1e7769444e33979b46b470b2f4c62d983c2560ce9486fb9be1491c9"}, + {file = "tantivy-0.22.0-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:97eb05f8585f321dbc733b64e7e917d061dc70c572c623730b366c216540d149"}, + {file = "tantivy-0.22.0-cp38-cp38-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:cc74748b6b886475c12bf47c8814861b79f850fb8a528f37ae0392caae0f6f14"}, + {file = "tantivy-0.22.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a7059c51c25148e07a20bd73efc8b51c015c220f141f3638489447b99229c8c0"}, + {file = "tantivy-0.22.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f88d05f55e2c3e581de70c5c7f46e94e5869d1c0fd48c5db33be7e56b6b88c9a"}, + {file = "tantivy-0.22.0-cp38-none-win_amd64.whl", hash = "sha256:09bf6de2fa08aac1a7133bee3631c1123de05130fd2991ceb101f2abac51b9d2"}, + {file = "tantivy-0.22.0-cp39-cp39-macosx_10_7_x86_64.whl", hash = 
"sha256:9de1a7497d377477dc09029c343eb9106c2c5fdb2e399f8dddd624cd9c7622a2"}, + {file = "tantivy-0.22.0-cp39-cp39-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:e81e47edd0faffb5ad20f52ae75c3a2ed680f836e72bc85c799688d3a2557502"}, + {file = "tantivy-0.22.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27333518dbc309299dafe79443ee80eede5526a489323cdb0506b95eb334f985"}, + {file = "tantivy-0.22.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c9452d05e42450be53a9a58a9cf13f9ff8d3605c73bdc38a34ce5e167a25d77"}, + {file = "tantivy-0.22.0-cp39-none-win_amd64.whl", hash = "sha256:51e4ec0d44637562bf23912d18d12850c4b3176c0719e7b019d43b59199a643c"}, + {file = "tantivy-0.22.0.tar.gz", hash = "sha256:dce07fa2910c94934aa3d96c91087936c24e4a5802d839625d67edc6d1c95e5c"}, +] + +[package.extras] +dev = ["nox"] + [[package]] name = "tblib" version = "2.0.0" @@ -9669,7 +9707,7 @@ duckdb = ["duckdb"] filesystem = ["botocore", "s3fs"] gcp = ["gcsfs", "google-cloud-bigquery", "grpcio"] gs = ["gcsfs"] -lancedb = ["lancedb", "pyarrow"] +lancedb = ["lancedb", "pyarrow", "tantivy"] motherduck = ["duckdb", "pyarrow"] mssql = ["pyodbc"] parquet = ["pyarrow"] @@ -9684,4 +9722,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "2b8d00f91f33a380b2399989dcac0d1d106d0bd2cd8865c5b7e27a19885753b5" +content-hash = "888e1760984e867fde690a1cca90330e255d69a8775c81020d003650def7ab4c" diff --git a/pyproject.toml b/pyproject.toml index d32285572f..1bdaf77b86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,7 @@ databricks-sql-connector = {version = ">=2.9.3", optional = true} clickhouse-driver = { version = ">=0.2.7", optional = true } clickhouse-connect = { version = ">=0.7.7", optional = true } lancedb = { version = ">=0.8.2", optional = true, markers = "python_version >= '3.9'", allow-prereleases = true } +tantivy = { version = ">= 0.22.0", optional = true } deltalake = { version = ">=0.19.0", optional = true } [tool.poetry.extras] @@ -105,7 +106,7 @@ qdrant = ["qdrant-client"] databricks = ["databricks-sql-connector"] clickhouse = ["clickhouse-driver", "clickhouse-connect", "s3fs", "gcsfs", "adlfs", "pyarrow"] dremio = ["pyarrow"] -lancedb = ["lancedb", "pyarrow"] +lancedb = ["lancedb", "pyarrow", "tantivy"] deltalake = ["deltalake", "pyarrow"] diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index e817a2f6c8..728127f833 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -1,6 +1,10 @@ -from typing import Iterator, Generator, Any, List +import multiprocessing +from typing import Iterator, Generator, Any, List, Mapping +import lancedb # type: ignore import pytest +from lancedb import DBConnection +from lancedb.embeddings import EmbeddingFunctionRegistry # type: ignore import dlt from dlt.common import json @@ -21,7 +25,7 @@ @pytest.fixture(autouse=True) -def drop_lancedb_data() -> Iterator[None]: +def drop_lancedb_data() -> Iterator[Any]: yield drop_active_pipeline_data() @@ -433,3 +437,126 @@ def test_empty_dataset_allowed() -> None: assert client.dataset_name is None assert client.sentinel_table == "dltSentinelTable" assert_table(pipe, "content", expected_items_count=3) + + +search_data = [ + {"text": "Frodo was a happy puppy"}, + {"text": "There are several kittens playing"}, +] + + +def test_fts_query() -> None: + @dlt.resource + def search_data_resource() -> Generator[Mapping[str, 
object], Any, None]: + yield from search_data + + pipeline = dlt.pipeline( + pipeline_name="test_fts_query", + destination="lancedb", + dataset_name=f"test_pipeline_append{uniq_id()}", + ) + info = pipeline.run( + search_data_resource(), + ) + assert_load_info(info) + + client: LanceDBClient + with pipeline.destination_client() as client: # type: ignore[assignment] + db_client: DBConnection = client.db_client + + table_name = client.make_qualified_table_name("search_data_resource") + tbl = db_client[table_name] + tbl.checkout_latest() + + tbl.create_fts_index("text") + results = tbl.search("kittens", query_type="fts").select(["text"]).to_list() + assert results[0]["text"] == "There are several kittens playing" + + +def test_semantic_query() -> None: + @dlt.resource + def search_data_resource() -> Generator[Mapping[str, object], Any, None]: + yield from search_data + + lancedb_adapter( + search_data_resource, + embed=["text"], + ) + + pipeline = dlt.pipeline( + pipeline_name="test_fts_query", + destination="lancedb", + dataset_name=f"test_pipeline_append{uniq_id()}", + ) + info = pipeline.run( + search_data_resource(), + ) + assert_load_info(info) + + client: LanceDBClient + with pipeline.destination_client() as client: # type: ignore[assignment] + db_client: DBConnection = client.db_client + + table_name = client.make_qualified_table_name("search_data_resource") + tbl = db_client[table_name] + tbl.checkout_latest() + + results = ( + tbl.search("puppy", query_type="vector", ordering_field_name="_distance") + .select(["text"]) + .to_list() + ) + assert results[0]["text"] == "Frodo was a happy puppy" + + +def test_semantic_query_custom_embedding_functions_registered() -> None: + """Test the LanceDB registry registered custom embedding functions defined in models, if any. + See: https://github.com/dlt-hub/dlt/issues/1765""" + + @dlt.resource + def search_data_resource() -> Generator[Mapping[str, object], Any, None]: + yield from search_data + + lancedb_adapter( + search_data_resource, + embed=["text"], + ) + + pipeline = dlt.pipeline( + pipeline_name="test_fts_query", + destination="lancedb", + dataset_name=f"test_pipeline_append{uniq_id()}", + ) + info = pipeline.run( + search_data_resource(), + ) + assert_load_info(info) + + client: LanceDBClient + with pipeline.destination_client() as client: # type: ignore[assignment] + db_client_uri = client.db_client.uri + table_name = client.make_qualified_table_name("search_data_resource") + + # A new python process doesn't seem to correctly deserialize the custom embedding + # functions into global __REGISTRY__. + # We make sure to reset it as well to make sure no globals are propagated to the spawned process. + EmbeddingFunctionRegistry().reset() + with multiprocessing.get_context("spawn").Pool(1) as pool: + results = pool.apply(run_lance_search_in_separate_process, (db_client_uri, table_name)) + + assert results[0]["text"] == "Frodo was a happy puppy" + + +def run_lance_search_in_separate_process(db_client_uri: str, table_name: str) -> Any: + import lancedb + + # Must read into __REGISTRY__ here. 
+ db = lancedb.connect(db_client_uri) + tbl = db[table_name] + tbl.checkout_latest() + + return ( + tbl.search("puppy", query_type="vector", ordering_field_name="_distance") + .select(["text"]) + .to_list() + ) diff --git a/tests/load/lancedb/utils.py b/tests/load/lancedb/utils.py index dc3ea5304b..7431e895b7 100644 --- a/tests/load/lancedb/utils.py +++ b/tests/load/lancedb/utils.py @@ -52,7 +52,7 @@ def assert_table( "_dlt_id", "_dlt_load_id", dlt.config.get("destination.lancedb.credentials.id_field_name", str) or "id__", - dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector__", + dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector", ] objects_without_dlt_or_special_keys = [ {k: v for k, v in record.items() if k not in drop_keys} for record in records
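
Note for reviewers: the hunk in lancedb_client.py above replaces empty or null
values in embedded source columns before the records ever reach OpenAI. Below
is a minimal standalone sketch of that substitution; the helper name and the
sample records are made up for illustration, while the placeholder constant
and the dict comprehension are taken verbatim from the patch.

    from typing import Any, Dict, List

    # Random placeholder string from the patch. OpenAI's embedding endpoint
    # rejects empty strings; see https://github.com/lancedb/lancedb/issues/1577.
    EMPTY_STRING_PLACEHOLDER = "0uEoDNBpQUBwsxKbmxxB"

    def replace_empty_strings(
        records: List[Dict[str, Any]], source_columns: List[str]
    ) -> List[Dict[str, Any]]:
        # Only columns marked for embedding (VECTORIZE_HINT) are touched;
        # both "" and None map to the placeholder.
        return [
            {
                k: EMPTY_STRING_PLACEHOLDER if k in source_columns and v in ("", None) else v
                for k, v in record.items()
            }
            for record in records
        ]

    records = [{"text": "", "id": 1}, {"text": "Frodo was a happy puppy", "id": 2}]
    print(replace_empty_strings(records, ["text"]))
    # [{'text': '0uEoDNBpQUBwsxKbmxxB', 'id': 1}, {'text': 'Frodo was a happy puppy', 'id': 2}]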
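
Note for reviewers: with models.py deleted, every provider, including OpenAI,
now resolves through LanceDB's stock embedding function registry. A rough
equivalent of the new model_func construction follows; the model name, retry
count, and API key are placeholder values standing in for the client's
config.embedding_model, config.options.max_retries, and
config.credentials.api_key.

    from lancedb.embeddings import EmbeddingFunctionRegistry

    registry = EmbeddingFunctionRegistry.get_instance()
    model_func = registry.get("openai").create(
        name="text-embedding-3-small",  # placeholder for config.embedding_model
        max_retries=3,                  # placeholder for config.options.max_retries
        api_key="sk-...",               # placeholder for config.credentials.api_key
    )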