Don't use Custom Embedding Functions #1771

Merged

Changes from 5 commits
2 changes: 1 addition & 1 deletion dlt/destinations/impl/lancedb/configuration.py
@@ -90,7 +90,7 @@ class LanceDBClientConfiguration(DestinationClientDwhConfiguration):
but it is configurable in rare cases.

Make sure it corresponds with the associated embedding model's dimensionality."""
vector_field_name: str = "vector__"
vector_field_name: str = "vector"
"""Name of the special field to store the vector embeddings."""
id_field_name: str = "id__"
"""Name of the special field to manage deduplication."""
36 changes: 19 additions & 17 deletions dlt/destinations/impl/lancedb/lancedb_client.py
@@ -1,7 +1,6 @@
import uuid
from types import TracebackType
from typing import (
ClassVar,
List,
Any,
cast,
@@ -37,7 +36,6 @@
RunnableLoadJob,
StorageSchemaInfo,
StateInfo,
TLoadJobState,
LoadJob,
)
from dlt.common.pendulum import timedelta
@@ -70,7 +68,6 @@
generate_uuid,
set_non_standard_providers_environment_variables,
)
from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs
from dlt.destinations.type_mapping import TypeMapper

if TYPE_CHECKING:
@@ -81,6 +78,7 @@

TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"}
UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()}
EMPTY_STRING_PLACEHOLDER = "__EMPTY_STRING_PLACEHOLDER__"
Collaborator:

Use a random string; who knows what kind of tokenizer may be used against it. OpenAI may embed this as separate words.

Collaborator (Author):

Ah, good point! You're right, I'll replace it with a randomly generated string.
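A minimal sketch of what that suggestion could look like; the constant name comes from the diff above, but the actual value chosen for the final commit may differ:

```python
import uuid

# An opaque placeholder with no recognizable words for a tokenizer to split.
# Generated once here; a pinned, pre-generated value would keep loads consistent
# across processes (illustrative only, not the merged implementation).
EMPTY_STRING_PLACEHOLDER = uuid.uuid4().hex
```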



class LanceDBTypeMapper(TypeMapper):
@@ -233,20 +231,11 @@ def __init__(
embedding_model_provider,
self.config.credentials.embedding_model_provider_api_key,
)
# Use the monkey-patched implementation if openai was chosen.
if embedding_model_provider == "openai":
from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings

self.model_func = PatchedOpenAIEmbeddings(
max_retries=self.config.options.max_retries,
api_key=self.config.credentials.api_key,
)
else:
self.model_func = self.registry.get(embedding_model_provider).create(
name=self.config.embedding_model,
max_retries=self.config.options.max_retries,
api_key=self.config.credentials.api_key,
)
self.model_func = self.registry.get(embedding_model_provider).create(
name=self.config.embedding_model,
max_retries=self.config.options.max_retries,
api_key=self.config.credentials.api_key,
)

self.vector_field_name = self.config.vector_field_name
self.id_field_name = self.config.id_field_name
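For context, a rough sketch of the registry pattern the simplified code path above relies on; the provider and model names are illustrative, and dlt resolves the real values from its configuration:

```python
from lancedb.embeddings import EmbeddingFunctionRegistry

# Look up a provider in LanceDB's global registry and build an embedding function,
# instead of monkey-patching a custom OpenAI implementation.
registry = EmbeddingFunctionRegistry.get_instance()
model_func = registry.get("openai").create(
    name="text-embedding-3-small",  # illustrative model name
    max_retries=3,
)
# model_func.compute_source_embeddings(["some text"]) then yields the vectors
# that LanceDB stores in the configured vector field.
```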
@@ -731,6 +720,19 @@ def run(self) -> None:
with FileStorage.open_zipsafe_ro(self._file_path) as f:
records: List[DictStrAny] = [json.loads(line) for line in f]

# Replace empty strings with placeholder string if OpenAI is used.
# https://github.com/lancedb/lancedb/issues/1577#issuecomment-2318104218.
if (self._job_client.config.embedding_model_provider == "openai") and (
source_columns := get_columns_names_with_prop(self._load_table, VECTORIZE_HINT)
):
records = [
{
k: EMPTY_STRING_PLACEHOLDER if k in source_columns and v in ("", None) else v
for k, v in record.items()
}
for record in records
]

if self._load_table not in self._schema.dlt_tables():
for record in records:
# Add reserved ID fields.

Collaborator:

Can't tell the impact on performance, but I think it's a good fix until there's progress on the LanceDB issue!

I don't know how frequently you'd hit an empty string when embedding, but it might be worth mentioning in the docs?

Collaborator:

@Pipboyguy didn't we switch the format to parquet? I think it is in a PR that is still in review. Anyway, we'll be able to use pa.compute to replace those soon.

Collaborator (Author):

@rudolfix Yes indeed, it does make it a bit tricky to implement a fix considering the switch in format.

Collaborator (Author):

@zilto Agreed, will add a doc entry for this!
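As a follow-up to the pa.compute idea in the thread above, here is a hedged sketch of how the same masking could be done on an Arrow table once the load format is parquet; the function name and signature are hypothetical and not part of this PR:

```python
import pyarrow as pa
import pyarrow.compute as pc

def mask_empty_strings(table: pa.Table, source_columns: list, placeholder: str) -> pa.Table:
    """Replace None and "" in the embedded source columns with the placeholder."""
    for name in source_columns:
        idx = table.schema.get_field_index(name)
        col = pc.fill_null(table.column(idx), placeholder)                # None -> placeholder
        col = pc.if_else(pc.equal(col, ""), pa.scalar(placeholder), col)  # ""   -> placeholder
        table = table.set_column(idx, name, col)
    return table
```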
34 changes: 0 additions & 34 deletions dlt/destinations/impl/lancedb/models.py

This file was deleted.

44 changes: 41 additions & 3 deletions poetry.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -80,6 +80,7 @@ databricks-sql-connector = {version = ">=2.9.3", optional = true}
clickhouse-driver = { version = ">=0.2.7", optional = true }
clickhouse-connect = { version = ">=0.7.7", optional = true }
lancedb = { version = ">=0.8.2", optional = true, markers = "python_version >= '3.9'", allow-prereleases = true }
tantivy = { version = ">= 0.22.0", optional = true }
deltalake = { version = ">=0.19.0", optional = true }

[tool.poetry.extras]
@@ -105,7 +106,7 @@ qdrant = ["qdrant-client"]
databricks = ["databricks-sql-connector"]
clickhouse = ["clickhouse-driver", "clickhouse-connect", "s3fs", "gcsfs", "adlfs", "pyarrow"]
dremio = ["pyarrow"]
lancedb = ["lancedb", "pyarrow"]
lancedb = ["lancedb", "pyarrow", "tantivy"]
deltalake = ["deltalake", "pyarrow"]


131 changes: 129 additions & 2 deletions tests/load/lancedb/test_pipeline.py
@@ -1,6 +1,10 @@
from typing import Iterator, Generator, Any, List
import multiprocessing
from typing import Iterator, Generator, Any, List, Mapping

import lancedb # type: ignore
import pytest
from lancedb import DBConnection
from lancedb.embeddings import EmbeddingFunctionRegistry # type: ignore

import dlt
from dlt.common import json
@@ -21,7 +25,7 @@


@pytest.fixture(autouse=True)
def drop_lancedb_data() -> Iterator[None]:
def drop_lancedb_data() -> Iterator[Any]:
yield
drop_active_pipeline_data()

@@ -433,3 +437,126 @@ def test_empty_dataset_allowed() -> None:
assert client.dataset_name is None
assert client.sentinel_table == "dltSentinelTable"
assert_table(pipe, "content", expected_items_count=3)


search_data = [
{"text": "Frodo was a happy puppy"},
{"text": "There are several kittens playing"},
]


def test_fts_query() -> None:
@dlt.resource
def search_data_resource() -> Generator[Mapping[str, object], Any, None]:
yield from search_data

pipeline = dlt.pipeline(
pipeline_name="test_fts_query",
destination="lancedb",
dataset_name=f"test_pipeline_append{uniq_id()}",
)
info = pipeline.run(
search_data_resource(),
)
assert_load_info(info)

client: LanceDBClient
with pipeline.destination_client() as client: # type: ignore[assignment]
db_client: DBConnection = client.db_client

table_name = client.make_qualified_table_name("search_data_resource")
tbl = db_client[table_name]
tbl.checkout_latest()

tbl.create_fts_index("text")
results = tbl.search("kittens", query_type="fts").select(["text"]).to_list()
assert results[0]["text"] == "There are several kittens playing"


def test_semantic_query() -> None:
@dlt.resource
def search_data_resource() -> Generator[Mapping[str, object], Any, None]:
yield from search_data

lancedb_adapter(
search_data_resource,
embed=["text"],
)

pipeline = dlt.pipeline(
pipeline_name="test_fts_query",
destination="lancedb",
dataset_name=f"test_pipeline_append{uniq_id()}",
)
info = pipeline.run(
search_data_resource(),
)
assert_load_info(info)

client: LanceDBClient
with pipeline.destination_client() as client: # type: ignore[assignment]
db_client: DBConnection = client.db_client

table_name = client.make_qualified_table_name("search_data_resource")
tbl = db_client[table_name]
tbl.checkout_latest()

results = (
tbl.search("puppy", query_type="vector", ordering_field_name="_distance")
.select(["text"])
.to_list()
)
assert results[0]["text"] == "Frodo was a happy puppy"


def test_semantic_query_custom_embedding_functions_registered() -> None:
"""Test the LanceDB registry registered custom embedding functions defined in models, if any.
See: https://github.com/dlt-hub/dlt/issues/1765"""

@dlt.resource
def search_data_resource() -> Generator[Mapping[str, object], Any, None]:
yield from search_data

lancedb_adapter(
search_data_resource,
embed=["text"],
)

pipeline = dlt.pipeline(
pipeline_name="test_fts_query",
destination="lancedb",
dataset_name=f"test_pipeline_append{uniq_id()}",
)
info = pipeline.run(
search_data_resource(),
)
assert_load_info(info)

client: LanceDBClient
with pipeline.destination_client() as client: # type: ignore[assignment]
db_client_uri = client.db_client.uri
table_name = client.make_qualified_table_name("search_data_resource")

# A new python process doesn't seem to correctly deserialize the custom embedding
# functions into global __REGISTRY__.
# We make sure to reset it as well to make sure no globals are propagated to the spawned process.
EmbeddingFunctionRegistry().reset()
with multiprocessing.get_context("spawn").Pool(1) as pool:
results = pool.apply(run_lance_search_in_separate_process, (db_client_uri, table_name))

assert results[0]["text"] == "Frodo was a happy puppy"


def run_lance_search_in_separate_process(db_client_uri: str, table_name: str) -> Any:
import lancedb

# Must read into __REGISTRY__ here.
db = lancedb.connect(db_client_uri)
tbl = db[table_name]
tbl.checkout_latest()

return (
tbl.search("puppy", query_type="vector", ordering_field_name="_distance")
.select(["text"])
.to_list()
)
2 changes: 1 addition & 1 deletion tests/load/lancedb/utils.py
@@ -52,7 +52,7 @@ def assert_table(
"_dlt_id",
"_dlt_load_id",
dlt.config.get("destination.lancedb.credentials.id_field_name", str) or "id__",
dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector__",
dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector",
Collaborator:

I think using "vector" is nice because it aligns with the LanceDB defaults.

]
objects_without_dlt_or_special_keys = [
{k: v for k, v in record.items() if k not in drop_keys} for record in records
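To illustrate the reviewer's point about defaults, a small hedged sketch of an idiomatic LanceDB schema, where the embedding column is conventionally named vector (the provider here is illustrative, and the column name is what the new dlt default now matches):

```python
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector

# Idiomatic LanceDB schema: the embedded column is conventionally called "vector",
# which the new dlt default (vector_field_name = "vector") lines up with.
func = EmbeddingFunctionRegistry.get_instance().get("openai").create()

class Document(LanceModel):
    text: str = func.SourceField()
    vector: Vector(func.ndims()) = func.VectorField()
```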