From 4380112ec16c762f46fd5f4116d70443a3bf0815 Mon Sep 17 00:00:00 2001 From: B-Step62 Date: Fri, 20 Sep 2024 17:10:29 +0900 Subject: [PATCH 1/2] Make 'endpoint' parameter optional for DatabricksVectorSearch Signed-off-by: B-Step62 --- .../langchain_databricks/vectorstores.py | 33 ++++++++++++++++--- .../tests/unit_tests/test_vectorstore.py | 22 +++++++------ 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/libs/databricks/langchain_databricks/vectorstores.py b/libs/databricks/langchain_databricks/vectorstores.py index 7359dcf..c2c01a1 100644 --- a/libs/databricks/langchain_databricks/vectorstores.py +++ b/libs/databricks/langchain_databricks/vectorstores.py @@ -89,7 +89,6 @@ class DatabricksVectorSearch(VectorStore): from langchain_databricks.vectorstores import DatabricksVectorSearch vector_store = DatabricksVectorSearch( - endpoint="", index_name="" ) @@ -102,12 +101,24 @@ class DatabricksVectorSearch(VectorStore): from langchain_openai import OpenAIEmbeddings vector_store = DatabricksVectorSearch( - endpoint="", index_name="", embedding=OpenAIEmbeddings(), text_column="document_content" ) + .. note:: + + If you are using `databricks-vectorsearch` version earlier than 0.35, you also need to + provide the `endpoint` parameter when initializing the vector store. + + .. code-block:: python + + vector_store = DatabricksVectorSearch( + endpoint="", + index_name="", + ... + ) + Add Documents: .. code-block:: python from langchain_core.documents import Document @@ -196,8 +207,8 @@ class DatabricksVectorSearch(VectorStore): def __init__( self, - endpoint: str, index_name: str, + endpoint: Optional[str] = None, embedding: Optional[Embeddings] = None, text_column: Optional[str] = None, columns: Optional[List[str]] = None, @@ -212,7 +223,21 @@ def __init__( "Please install it with `pip install databricks-vectorsearch`." ) from e - self.index = VectorSearchClient().get_index(endpoint, index_name) + try: + self.index = VectorSearchClient().get_index( + endpoint_name=endpoint, index_name=index_name + ) + except Exception as e: + if endpoint is None and "Wrong vector search endpoint" in str(e): + raise ValueError( + "The `endpoint` parameter is required for instantiating " + "DatabricksVectorSearch with the `databricks-vectorsearch` " + "version earlier than 0.35. Please provide the endpoint " + "name or upgrade to version 0.35 or later." + ) from e + else: + raise + self._index_details = IndexDetails(self.index) _validate_embedding(embedding, self._index_details) diff --git a/libs/databricks/tests/unit_tests/test_vectorstore.py b/libs/databricks/tests/unit_tests/test_vectorstore.py index ed8654e..164cd5a 100644 --- a/libs/databricks/tests/unit_tests/test_vectorstore.py +++ b/libs/databricks/tests/unit_tests/test_vectorstore.py @@ -133,12 +133,12 @@ def embed_query(self, text: str) -> List[float]: @pytest.fixture(autouse=True) def mock_vs_client() -> Generator: - def _get_index(endpoint: str, index_name: str) -> MagicMock: + def _get_index( + endpoint_name: Optional[str] = None, + index_name: str = None, # type: ignore + ) -> MagicMock: from databricks.vector_search.client import VectorSearchIndex # type: ignore - if endpoint != ENDPOINT_NAME: - raise ValueError(f"Unknown endpoint: {endpoint}") - index = MagicMock(spec=VectorSearchIndex) index.describe.return_value = INDEX_DETAILS[index_name] index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE @@ -157,7 +157,6 @@ def init_vector_search( index_name: str, columns: Optional[List[str]] = None ) -> DatabricksVectorSearch: kwargs: Dict[str, Any] = { - "endpoint": ENDPOINT_NAME, "index_name": index_name, "columns": columns, } @@ -177,10 +176,17 @@ def test_init(index_name: str) -> None: assert vectorsearch.index.describe() == INDEX_DETAILS[index_name] +def test_init_with_endpoint_name() -> None: + vectorsearch = DatabricksVectorSearch( + endpoint=ENDPOINT_NAME, + index_name=DELTA_SYNC_INDEX, + ) + assert vectorsearch.index.describe() == INDEX_DETAILS[DELTA_SYNC_INDEX] + + def test_init_fail_text_column_mismatch() -> None: with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' has"): DatabricksVectorSearch( - endpoint=ENDPOINT_NAME, index_name=DELTA_SYNC_INDEX, text_column="some_other_column", ) @@ -190,7 +196,6 @@ def test_init_fail_text_column_mismatch() -> None: def test_init_fail_no_text_column(index_name: str) -> None: with pytest.raises(ValueError, match="The `text_column` parameter is required"): DatabricksVectorSearch( - endpoint=ENDPOINT_NAME, index_name=index_name, embedding=EMBEDDING_MODEL, ) @@ -206,7 +211,6 @@ def test_init_fail_columns_not_in_schema() -> None: def test_init_fail_no_embedding(index_name: str) -> None: with pytest.raises(ValueError, match="The `embedding` parameter is required"): DatabricksVectorSearch( - endpoint=ENDPOINT_NAME, index_name=index_name, text_column="text", ) @@ -215,7 +219,6 @@ def test_init_fail_no_embedding(index_name: str) -> None: def test_init_fail_embedding_already_specified_in_source() -> None: with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' uses"): DatabricksVectorSearch( - endpoint=ENDPOINT_NAME, index_name=DELTA_SYNC_INDEX, embedding=EMBEDDING_MODEL, ) @@ -227,7 +230,6 @@ def test_init_fail_embedding_dim_mismatch(index_name: str) -> None: ValueError, match="embedding model's dimension '1000' does not match" ): DatabricksVectorSearch( - endpoint=ENDPOINT_NAME, index_name=index_name, text_column="text", embedding=FakeEmbeddings(1000), From 3d06f42001b2849b3625913e6abb79c43ee754e6 Mon Sep 17 00:00:00 2001 From: B-Step62 Date: Fri, 20 Sep 2024 19:10:45 +0900 Subject: [PATCH 2/2] comments Signed-off-by: B-Step62 --- .../langchain_databricks/vectorstores.py | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/libs/databricks/langchain_databricks/vectorstores.py b/libs/databricks/langchain_databricks/vectorstores.py index c2c01a1..097b468 100644 --- a/libs/databricks/langchain_databricks/vectorstores.py +++ b/libs/databricks/langchain_databricks/vectorstores.py @@ -60,8 +60,23 @@ class DatabricksVectorSearch(VectorStore): Key init args — indexing params: - endpoint: The name of the Databricks Vector Search endpoint. index_name: The name of the index to use. Format: "catalog.schema.index". + endpoint: The name of the Databricks Vector Search endpoint. If not specified, + the endpoint name is automatically inferred based on the index name. + + .. note:: + + If you are using `databricks-vectorsearch` version < 0.35, the `endpoint` parameter + is required when initializing the vector store. + + .. code-block:: python + + vector_store = DatabricksVectorSearch( + endpoint="", + index_name="", + ... + ) + embedding: The embedding model. Required for direct-access index or delta-sync index with self-managed embeddings. @@ -106,19 +121,6 @@ class DatabricksVectorSearch(VectorStore): text_column="document_content" ) - .. note:: - - If you are using `databricks-vectorsearch` version earlier than 0.35, you also need to - provide the `endpoint` parameter when initializing the vector store. - - .. code-block:: python - - vector_store = DatabricksVectorSearch( - endpoint="", - index_name="", - ... - ) - Add Documents: .. code-block:: python from langchain_core.documents import Document