Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make 'endpoint' parameter optional for DatabricksVectorSearch #17

Merged
merged 2 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions libs/databricks/langchain_databricks/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,23 @@ class DatabricksVectorSearch(VectorStore):

Key init args — indexing params:

endpoint: The name of the Databricks Vector Search endpoint.
index_name: The name of the index to use. Format: "catalog.schema.index".
endpoint: The name of the Databricks Vector Search endpoint. If not specified,
the endpoint name is automatically inferred based on the index name.

.. note::

If you are using `databricks-vectorsearch` version < 0.35, the `endpoint` parameter
is required when initializing the vector store.

.. code-block:: python

vector_store = DatabricksVectorSearch(
endpoint="<your-endpoint-name>",
index_name="<your-index-name>",
...
)

embedding: The embedding model.
Required for direct-access index or delta-sync index
with self-managed embeddings.
Expand Down Expand Up @@ -89,7 +104,6 @@ class DatabricksVectorSearch(VectorStore):
from langchain_databricks.vectorstores import DatabricksVectorSearch

vector_store = DatabricksVectorSearch(
endpoint="<your-endpoint-name>",
index_name="<your-index-name>"
)

Expand All @@ -102,7 +116,6 @@ class DatabricksVectorSearch(VectorStore):
from langchain_openai import OpenAIEmbeddings

vector_store = DatabricksVectorSearch(
endpoint="<your-endpoint-name>",
index_name="<your-index-name>",
embedding=OpenAIEmbeddings(),
text_column="document_content"
Expand Down Expand Up @@ -196,8 +209,8 @@ class DatabricksVectorSearch(VectorStore):

def __init__(
self,
endpoint: str,
index_name: str,
endpoint: Optional[str] = None,
embedding: Optional[Embeddings] = None,
text_column: Optional[str] = None,
columns: Optional[List[str]] = None,
Expand All @@ -212,7 +225,21 @@ def __init__(
"Please install it with `pip install databricks-vectorsearch`."
) from e

self.index = VectorSearchClient().get_index(endpoint, index_name)
try:
self.index = VectorSearchClient().get_index(
endpoint_name=endpoint, index_name=index_name
)
except Exception as e:
if endpoint is None and "Wrong vector search endpoint" in str(e):
Copy link
Collaborator

@harupy harupy Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where does Wrong vector search endpoint come from?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is the exception raised from the VectorSearch client (< 0.35). It is not open-sourced or documented either unfortunately😔

raise ValueError(
"The `endpoint` parameter is required for instantiating "
"DatabricksVectorSearch with the `databricks-vectorsearch` "
"version earlier than 0.35. Please provide the endpoint "
"name or upgrade to version 0.35 or later."
) from e
else:
raise

self._index_details = IndexDetails(self.index)

_validate_embedding(embedding, self._index_details)
Expand Down
22 changes: 12 additions & 10 deletions libs/databricks/tests/unit_tests/test_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,12 @@ def embed_query(self, text: str) -> List[float]:

@pytest.fixture(autouse=True)
def mock_vs_client() -> Generator:
def _get_index(endpoint: str, index_name: str) -> MagicMock:
def _get_index(
endpoint_name: Optional[str] = None,
index_name: str = None, # type: ignore
harupy marked this conversation as resolved.
Show resolved Hide resolved
) -> MagicMock:
from databricks.vector_search.client import VectorSearchIndex # type: ignore

if endpoint != ENDPOINT_NAME:
raise ValueError(f"Unknown endpoint: {endpoint}")

index = MagicMock(spec=VectorSearchIndex)
index.describe.return_value = INDEX_DETAILS[index_name]
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
Expand All @@ -157,7 +157,6 @@ def init_vector_search(
index_name: str, columns: Optional[List[str]] = None
) -> DatabricksVectorSearch:
kwargs: Dict[str, Any] = {
"endpoint": ENDPOINT_NAME,
"index_name": index_name,
"columns": columns,
}
Expand All @@ -177,10 +176,17 @@ def test_init(index_name: str) -> None:
assert vectorsearch.index.describe() == INDEX_DETAILS[index_name]


def test_init_with_endpoint_name() -> None:
vectorsearch = DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=DELTA_SYNC_INDEX,
)
assert vectorsearch.index.describe() == INDEX_DETAILS[DELTA_SYNC_INDEX]


def test_init_fail_text_column_mismatch() -> None:
with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' has"):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=DELTA_SYNC_INDEX,
text_column="some_other_column",
)
Expand All @@ -190,7 +196,6 @@ def test_init_fail_text_column_mismatch() -> None:
def test_init_fail_no_text_column(index_name: str) -> None:
with pytest.raises(ValueError, match="The `text_column` parameter is required"):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=index_name,
embedding=EMBEDDING_MODEL,
)
Expand All @@ -206,7 +211,6 @@ def test_init_fail_columns_not_in_schema() -> None:
def test_init_fail_no_embedding(index_name: str) -> None:
with pytest.raises(ValueError, match="The `embedding` parameter is required"):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=index_name,
text_column="text",
)
Expand All @@ -215,7 +219,6 @@ def test_init_fail_no_embedding(index_name: str) -> None:
def test_init_fail_embedding_already_specified_in_source() -> None:
with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' uses"):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=DELTA_SYNC_INDEX,
embedding=EMBEDDING_MODEL,
)
Expand All @@ -227,7 +230,6 @@ def test_init_fail_embedding_dim_mismatch(index_name: str) -> None:
ValueError, match="embedding model's dimension '1000' does not match"
):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=index_name,
text_column="text",
embedding=FakeEmbeddings(1000),
Expand Down
Loading