From c07c8fcf8f74f05719d6e20d896d76ab3b55ae00 Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Thu, 29 Aug 2024 22:39:22 +0200 Subject: [PATCH] Spawn process in test to make sure registry can be deserialized from arrow files Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 34 +++++++++---------- dlt/destinations/impl/lancedb/models.py | 34 ------------------- tests/load/lancedb/test_pipeline.py | 18 +++++++--- 3 files changed, 31 insertions(+), 55 deletions(-) delete mode 100644 dlt/destinations/impl/lancedb/models.py diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 78a37952b9..f28fdac78d 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -1,7 +1,6 @@ import uuid from types import TracebackType from typing import ( - ClassVar, List, Any, cast, @@ -37,7 +36,6 @@ RunnableLoadJob, StorageSchemaInfo, StateInfo, - TLoadJobState, LoadJob, ) from dlt.common.pendulum import timedelta @@ -70,7 +68,6 @@ generate_uuid, set_non_standard_providers_environment_variables, ) -from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -81,6 +78,7 @@ TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} +EMPTY_STRING_PLACEHOLDER = "__EMPTY_STRING_PLACEHOLDER__" class LanceDBTypeMapper(TypeMapper): @@ -233,20 +231,11 @@ def __init__( embedding_model_provider, self.config.credentials.embedding_model_provider_api_key, ) - # Use the monkey-patched implementation if openai was chosen. - if embedding_model_provider == "openai": - from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings - - self.model_func = PatchedOpenAIEmbeddings( - max_retries=self.config.options.max_retries, - api_key=self.config.credentials.api_key, - ) - else: - self.model_func = self.registry.get(embedding_model_provider).create( - name=self.config.embedding_model, - max_retries=self.config.options.max_retries, - api_key=self.config.credentials.api_key, - ) + self.model_func = self.registry.get(embedding_model_provider).create( + name=self.config.embedding_model, + max_retries=self.config.options.max_retries, + api_key=self.config.credentials.api_key, + ) self.vector_field_name = self.config.vector_field_name self.id_field_name = self.config.id_field_name @@ -731,6 +720,17 @@ def run(self) -> None: with FileStorage.open_zipsafe_ro(self._file_path) as f: records: List[DictStrAny] = [json.loads(line) for line in f] + # Replace empty strings with placeholder string if OpenAI is used. + # https://github.com/lancedb/lancedb/issues/1577#issuecomment-2318104218. + if (self._job_client.config.embedding_model_provider == "openai") and ( + source_columns := get_columns_names_with_prop(self._load_table, VECTORIZE_HINT) + ): + records: List[Dict[str, Any]] + for record in records: + for k, v in record.items(): + if k in source_columns and not v: + record[k] = EMPTY_STRING_PLACEHOLDER + if self._load_table not in self._schema.dlt_tables(): for record in records: # Add reserved ID fields. diff --git a/dlt/destinations/impl/lancedb/models.py b/dlt/destinations/impl/lancedb/models.py deleted file mode 100644 index d90adb62bd..0000000000 --- a/dlt/destinations/impl/lancedb/models.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import Union, List - -import numpy as np -from lancedb.embeddings import OpenAIEmbeddings # type: ignore -from lancedb.embeddings.registry import register # type: ignore -from lancedb.embeddings.utils import TEXT # type: ignore - - -@register("openai_patched") -class PatchedOpenAIEmbeddings(OpenAIEmbeddings): - EMPTY_STRING_PLACEHOLDER: str = "___EMPTY___" - - def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]: # type: ignore[type-arg] - """ - Replace empty strings with a placeholder value. - """ - - sanitized_texts = super().sanitize_input(texts) - return [self.EMPTY_STRING_PLACEHOLDER if item == "" else item for item in sanitized_texts] - - def generate_embeddings( - self, - texts: Union[List[str], np.ndarray], # type: ignore[type-arg] - ) -> List[np.array]: # type: ignore[valid-type] - """ - Generate embeddings, treating the placeholder as an empty result. - """ - embeddings: List[np.array] = super().generate_embeddings(texts) # type: ignore[valid-type] - - for i, text in enumerate(texts): - if text == self.EMPTY_STRING_PLACEHOLDER: - embeddings[i] = np.zeros(self.ndims()) - - return embeddings diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 3904dcdb1a..728127f833 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -1,3 +1,4 @@ +import multiprocessing from typing import Iterator, Generator, Any, List, Mapping import lancedb # type: ignore @@ -536,17 +537,26 @@ def search_data_resource() -> Generator[Mapping[str, object], Any, None]: db_client_uri = client.db_client.uri table_name = client.make_qualified_table_name("search_data_resource") - # A new python process doesn't seem to correctly deserialize the custom embedding functions into global __REGISTRY__. - EmbeddingFunctionRegistry.get_instance().reset() + # A new python process doesn't seem to correctly deserialize the custom embedding + # functions into global __REGISTRY__. + # We make sure to reset it as well to make sure no globals are propagated to the spawned process. + EmbeddingFunctionRegistry().reset() + with multiprocessing.get_context("spawn").Pool(1) as pool: + results = pool.apply(run_lance_search_in_separate_process, (db_client_uri, table_name)) + + assert results[0]["text"] == "Frodo was a happy puppy" + + +def run_lance_search_in_separate_process(db_client_uri: str, table_name: str) -> Any: + import lancedb # Must read into __REGISTRY__ here. db = lancedb.connect(db_client_uri) tbl = db[table_name] tbl.checkout_latest() - results = ( + return ( tbl.search("puppy", query_type="vector", ordering_field_name="_distance") .select(["text"]) .to_list() ) - assert results[0]["text"] == "Frodo was a happy puppy"