Skip to content

Commit

Permalink
Don't use Custom Embedding Functions (#1771)
Browse files Browse the repository at this point in the history
* - Change default vector column name to "vector" to conform with lancedb standard
- Add search tests with tantivy as search engine

Signed-off-by: Marcel Coetzee <[email protected]>

* Format and fix linting

Signed-off-by: Marcel Coetzee <[email protected]>

* Add custom embedding function registration test

Signed-off-by: Marcel Coetzee <[email protected]>

* Spawn process in test to make sure registry can be deserialized from arrow files

Signed-off-by: Marcel Coetzee <[email protected]>

* Simplify null string handling

Signed-off-by: Marcel Coetzee <[email protected]>

* Change NULL string replacement with random string, doc clarification

Signed-off-by: Marcel Coetzee <[email protected]>

* Update default vector column name in docs

Signed-off-by: Marcel Coetzee <[email protected]>

---------

Signed-off-by: Marcel Coetzee <[email protected]>
  • Loading branch information
Pipboyguy authored Sep 3, 2024
1 parent 36c0d14 commit dd973c5
Show file tree
Hide file tree
Showing 8 changed files with 204 additions and 60 deletions.
2 changes: 1 addition & 1 deletion dlt/destinations/impl/lancedb/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class LanceDBClientConfiguration(DestinationClientDwhConfiguration):
but it is configurable in rare cases.
Make sure it corresponds with the associated embedding model's dimensionality."""
vector_field_name: str = "vector__"
vector_field_name: str = "vector"
"""Name of the special field to store the vector embeddings."""
id_field_name: str = "id__"
"""Name of the special field to manage deduplication."""
Expand Down
36 changes: 19 additions & 17 deletions dlt/destinations/impl/lancedb/lancedb_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import uuid
from types import TracebackType
from typing import (
ClassVar,
List,
Any,
cast,
Expand Down Expand Up @@ -37,7 +36,6 @@
RunnableLoadJob,
StorageSchemaInfo,
StateInfo,
TLoadJobState,
LoadJob,
)
from dlt.common.pendulum import timedelta
Expand Down Expand Up @@ -70,7 +68,6 @@
generate_uuid,
set_non_standard_providers_environment_variables,
)
from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs
from dlt.destinations.type_mapping import TypeMapper

if TYPE_CHECKING:
Expand All @@ -81,6 +78,7 @@

TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"}
UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()}
EMPTY_STRING_PLACEHOLDER = "0uEoDNBpQUBwsxKbmxxB"


class LanceDBTypeMapper(TypeMapper):
Expand Down Expand Up @@ -239,20 +237,11 @@ def __init__(
embedding_model_provider,
self.config.credentials.embedding_model_provider_api_key,
)
# Use the monkey-patched implementation if openai was chosen.
if embedding_model_provider == "openai":
from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings

self.model_func = PatchedOpenAIEmbeddings(
max_retries=self.config.options.max_retries,
api_key=self.config.credentials.api_key,
)
else:
self.model_func = self.registry.get(embedding_model_provider).create(
name=self.config.embedding_model,
max_retries=self.config.options.max_retries,
api_key=self.config.credentials.api_key,
)
self.model_func = self.registry.get(embedding_model_provider).create(
name=self.config.embedding_model,
max_retries=self.config.options.max_retries,
api_key=self.config.credentials.api_key,
)

self.vector_field_name = self.config.vector_field_name
self.id_field_name = self.config.id_field_name
Expand Down Expand Up @@ -737,6 +726,19 @@ def run(self) -> None:
with FileStorage.open_zipsafe_ro(self._file_path) as f:
records: List[DictStrAny] = [json.loads(line) for line in f]

# Replace empty strings with placeholder string if OpenAI is used.
# https://github.com/lancedb/lancedb/issues/1577#issuecomment-2318104218.
if (self._job_client.config.embedding_model_provider == "openai") and (
source_columns := get_columns_names_with_prop(self._load_table, VECTORIZE_HINT)
):
records = [
{
k: EMPTY_STRING_PLACEHOLDER if k in source_columns and v in ("", None) else v
for k, v in record.items()
}
for record in records
]

if self._load_table not in self._schema.dlt_tables():
for record in records:
# Add reserved ID fields.
Expand Down
34 changes: 0 additions & 34 deletions dlt/destinations/impl/lancedb/models.py

This file was deleted.

12 changes: 11 additions & 1 deletion docs/website/docs/dlt-ecosystem/destinations/lancedb.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ This is the default disposition. It will append the data to the existing data in
## Additional Destination Options

- `dataset_separator`: The character used to separate the dataset name from table names. Defaults to "___".
- `vector_field_name`: The name of the special field to store vector embeddings. Defaults to "vector__".
- `vector_field_name`: The name of the special field to store vector embeddings. Defaults to "vector".
- `id_field_name`: The name of the special field used for deduplication and merging. Defaults to "id__".
- `max_retries`: The maximum number of retries for embedding operations. Set to 0 to disable retries. Defaults to 3.

Expand All @@ -216,11 +216,21 @@ The LanceDB destination supports syncing of the `dlt` state.

## Current Limitations

### In-Memory Tables

Adding new fields to an existing LanceDB table requires loading the entire table data into memory as a PyArrow table.
This is because PyArrow tables are immutable, so adding fields requires creating a new table with the updated schema.

For huge tables, this may impact performance and memory usage since the full table must be loaded into memory to add the new fields.
Keep these considerations in mind when working with large datasets and monitor memory usage if adding fields to sizable existing tables.

### Null string handling for OpenAI embeddings

OpenAI embedding service doesn't accept empty string bodies. We deal with this by replacing empty strings with a placeholder that should be very semantically dissimilar to 99.9% of queries.

If your source column (column which is embedded) has empty values, it is important to consider the impact of this. There might be a _slight_ change that semantic queries can hit these empty strings.

We reported this issue to LanceDB: https://github.com/lancedb/lancedb/issues/1577.

<!--@@@DLT_TUBA lancedb-->

44 changes: 41 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ databricks-sql-connector = {version = ">=2.9.3", optional = true}
clickhouse-driver = { version = ">=0.2.7", optional = true }
clickhouse-connect = { version = ">=0.7.7", optional = true }
lancedb = { version = ">=0.8.2", optional = true, markers = "python_version >= '3.9'", allow-prereleases = true }
tantivy = { version = ">= 0.22.0", optional = true }
deltalake = { version = ">=0.19.0", optional = true }

[tool.poetry.extras]
Expand All @@ -105,7 +106,7 @@ qdrant = ["qdrant-client"]
databricks = ["databricks-sql-connector"]
clickhouse = ["clickhouse-driver", "clickhouse-connect", "s3fs", "gcsfs", "adlfs", "pyarrow"]
dremio = ["pyarrow"]
lancedb = ["lancedb", "pyarrow"]
lancedb = ["lancedb", "pyarrow", "tantivy"]
deltalake = ["deltalake", "pyarrow"]


Expand Down
131 changes: 129 additions & 2 deletions tests/load/lancedb/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Iterator, Generator, Any, List
import multiprocessing
from typing import Iterator, Generator, Any, List, Mapping

import lancedb # type: ignore
import pytest
from lancedb import DBConnection
from lancedb.embeddings import EmbeddingFunctionRegistry # type: ignore

import dlt
from dlt.common import json
Expand All @@ -21,7 +25,7 @@


@pytest.fixture(autouse=True)
def drop_lancedb_data() -> Iterator[None]:
def drop_lancedb_data() -> Iterator[Any]:
yield
drop_active_pipeline_data()

Expand Down Expand Up @@ -433,3 +437,126 @@ def test_empty_dataset_allowed() -> None:
assert client.dataset_name is None
assert client.sentinel_table == "dltSentinelTable"
assert_table(pipe, "content", expected_items_count=3)


search_data = [
{"text": "Frodo was a happy puppy"},
{"text": "There are several kittens playing"},
]


def test_fts_query() -> None:
@dlt.resource
def search_data_resource() -> Generator[Mapping[str, object], Any, None]:
yield from search_data

pipeline = dlt.pipeline(
pipeline_name="test_fts_query",
destination="lancedb",
dataset_name=f"test_pipeline_append{uniq_id()}",
)
info = pipeline.run(
search_data_resource(),
)
assert_load_info(info)

client: LanceDBClient
with pipeline.destination_client() as client: # type: ignore[assignment]
db_client: DBConnection = client.db_client

table_name = client.make_qualified_table_name("search_data_resource")
tbl = db_client[table_name]
tbl.checkout_latest()

tbl.create_fts_index("text")
results = tbl.search("kittens", query_type="fts").select(["text"]).to_list()
assert results[0]["text"] == "There are several kittens playing"


def test_semantic_query() -> None:
@dlt.resource
def search_data_resource() -> Generator[Mapping[str, object], Any, None]:
yield from search_data

lancedb_adapter(
search_data_resource,
embed=["text"],
)

pipeline = dlt.pipeline(
pipeline_name="test_fts_query",
destination="lancedb",
dataset_name=f"test_pipeline_append{uniq_id()}",
)
info = pipeline.run(
search_data_resource(),
)
assert_load_info(info)

client: LanceDBClient
with pipeline.destination_client() as client: # type: ignore[assignment]
db_client: DBConnection = client.db_client

table_name = client.make_qualified_table_name("search_data_resource")
tbl = db_client[table_name]
tbl.checkout_latest()

results = (
tbl.search("puppy", query_type="vector", ordering_field_name="_distance")
.select(["text"])
.to_list()
)
assert results[0]["text"] == "Frodo was a happy puppy"


def test_semantic_query_custom_embedding_functions_registered() -> None:
"""Test the LanceDB registry registered custom embedding functions defined in models, if any.
See: https://github.com/dlt-hub/dlt/issues/1765"""

@dlt.resource
def search_data_resource() -> Generator[Mapping[str, object], Any, None]:
yield from search_data

lancedb_adapter(
search_data_resource,
embed=["text"],
)

pipeline = dlt.pipeline(
pipeline_name="test_fts_query",
destination="lancedb",
dataset_name=f"test_pipeline_append{uniq_id()}",
)
info = pipeline.run(
search_data_resource(),
)
assert_load_info(info)

client: LanceDBClient
with pipeline.destination_client() as client: # type: ignore[assignment]
db_client_uri = client.db_client.uri
table_name = client.make_qualified_table_name("search_data_resource")

# A new python process doesn't seem to correctly deserialize the custom embedding
# functions into global __REGISTRY__.
# We make sure to reset it as well to make sure no globals are propagated to the spawned process.
EmbeddingFunctionRegistry().reset()
with multiprocessing.get_context("spawn").Pool(1) as pool:
results = pool.apply(run_lance_search_in_separate_process, (db_client_uri, table_name))

assert results[0]["text"] == "Frodo was a happy puppy"


def run_lance_search_in_separate_process(db_client_uri: str, table_name: str) -> Any:
import lancedb

# Must read into __REGISTRY__ here.
db = lancedb.connect(db_client_uri)
tbl = db[table_name]
tbl.checkout_latest()

return (
tbl.search("puppy", query_type="vector", ordering_field_name="_distance")
.select(["text"])
.to_list()
)
2 changes: 1 addition & 1 deletion tests/load/lancedb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def assert_table(
"_dlt_id",
"_dlt_load_id",
dlt.config.get("destination.lancedb.credentials.id_field_name", str) or "id__",
dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector__",
dlt.config.get("destination.lancedb.credentials.vector_field_name", str) or "vector",
]
objects_without_dlt_or_special_keys = [
{k: v for k, v in record.items() if k not in drop_keys} for record in records
Expand Down

0 comments on commit dd973c5

Please sign in to comment.