Skip to content

Commit

Permalink
Spawn process in test to make sure registry can be deserialized from …
Browse files Browse the repository at this point in the history
…arrow files

Signed-off-by: Marcel Coetzee <[email protected]>
  • Loading branch information
Pipboyguy committed Aug 29, 2024
1 parent 703c4a8 commit c07c8fc
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 55 deletions.
34 changes: 17 additions & 17 deletions dlt/destinations/impl/lancedb/lancedb_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import uuid
from types import TracebackType
from typing import (
ClassVar,
List,
Any,
cast,
Expand Down Expand Up @@ -37,7 +36,6 @@
RunnableLoadJob,
StorageSchemaInfo,
StateInfo,
TLoadJobState,
LoadJob,
)
from dlt.common.pendulum import timedelta
Expand Down Expand Up @@ -70,7 +68,6 @@
generate_uuid,
set_non_standard_providers_environment_variables,
)
from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs
from dlt.destinations.type_mapping import TypeMapper

if TYPE_CHECKING:
Expand All @@ -81,6 +78,7 @@

TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"}
UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()}
EMPTY_STRING_PLACEHOLDER = "__EMPTY_STRING_PLACEHOLDER__"


class LanceDBTypeMapper(TypeMapper):
Expand Down Expand Up @@ -233,20 +231,11 @@ def __init__(
embedding_model_provider,
self.config.credentials.embedding_model_provider_api_key,
)
# Use the monkey-patched implementation if openai was chosen.
if embedding_model_provider == "openai":
from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings

self.model_func = PatchedOpenAIEmbeddings(
max_retries=self.config.options.max_retries,
api_key=self.config.credentials.api_key,
)
else:
self.model_func = self.registry.get(embedding_model_provider).create(
name=self.config.embedding_model,
max_retries=self.config.options.max_retries,
api_key=self.config.credentials.api_key,
)
self.model_func = self.registry.get(embedding_model_provider).create(
name=self.config.embedding_model,
max_retries=self.config.options.max_retries,
api_key=self.config.credentials.api_key,
)

self.vector_field_name = self.config.vector_field_name
self.id_field_name = self.config.id_field_name
Expand Down Expand Up @@ -731,6 +720,17 @@ def run(self) -> None:
with FileStorage.open_zipsafe_ro(self._file_path) as f:
records: List[DictStrAny] = [json.loads(line) for line in f]

# Replace empty strings with a placeholder string if the OpenAI embedding provider is used.
# https://github.com/lancedb/lancedb/issues/1577#issuecomment-2318104218.
if (self._job_client.config.embedding_model_provider == "openai") and (
source_columns := get_columns_names_with_prop(self._load_table, VECTORIZE_HINT)
):
records: List[Dict[str, Any]]
for record in records:
for k, v in record.items():
if k in source_columns and not v:
record[k] = EMPTY_STRING_PLACEHOLDER

if self._load_table not in self._schema.dlt_tables():
for record in records:
# Add reserved ID fields.
Expand Down
34 changes: 0 additions & 34 deletions dlt/destinations/impl/lancedb/models.py

This file was deleted.

18 changes: 14 additions & 4 deletions tests/load/lancedb/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import multiprocessing
from typing import Iterator, Generator, Any, List, Mapping

import lancedb # type: ignore
Expand Down Expand Up @@ -536,17 +537,26 @@ def search_data_resource() -> Generator[Mapping[str, object], Any, None]:
db_client_uri = client.db_client.uri
table_name = client.make_qualified_table_name("search_data_resource")

# A new Python process doesn't seem to correctly deserialize the custom embedding functions into the global __REGISTRY__.
EmbeddingFunctionRegistry.get_instance().reset()
# A new Python process doesn't seem to correctly deserialize the custom embedding
# functions into the global __REGISTRY__.
# We also reset it so that no globals are propagated to the spawned process.
EmbeddingFunctionRegistry().reset()
with multiprocessing.get_context("spawn").Pool(1) as pool:
results = pool.apply(run_lance_search_in_separate_process, (db_client_uri, table_name))

assert results[0]["text"] == "Frodo was a happy puppy"


def run_lance_search_in_separate_process(db_client_uri: str, table_name: str) -> Any:
import lancedb

# The lancedb import must happen inside this function so that __REGISTRY__ is
# read/populated in the spawned process — presumably module import triggers the
# registry load; confirm against lancedb internals.
db = lancedb.connect(db_client_uri)
tbl = db[table_name]
tbl.checkout_latest()

results = (
return (
tbl.search("puppy", query_type="vector", ordering_field_name="_distance")
.select(["text"])
.to_list()
)
assert results[0]["text"] == "Frodo was a happy puppy"

0 comments on commit c07c8fc

Please sign in to comment.