forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
community[minor]: Self query retriever for HANA Cloud Vector Engine (l…
…angchain-ai#24494) Description: - This PR adds a self query retriever implementation for SAP HANA Cloud Vector Engine. The retriever supports all operators except for contains. - Issue: N/A - Dependencies: no new dependencies added **Add tests and docs:** Added integration tests to: libs/community/tests/unit_tests/query_constructors/test_hanavector.py **Documentation for self query retriever:** /docs/integrations/retrievers/self_query/hanavector_self_query.ipynb --------- Co-authored-by: Bagatur <[email protected]> Co-authored-by: Bagatur <[email protected]>
- Loading branch information
1 parent
4f3b4fc
commit b65ac8d
Showing
5 changed files
with
408 additions
and
2 deletions.
There are no files selected for viewing
246 changes: 246 additions & 0 deletions
246
docs/docs/integrations/retrievers/self_query/hanavector_self_query.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# SAP HANA Cloud Vector Engine\n", | ||
"\n", | ||
"For more information on how to setup the SAP HANA vetor store, take a look at the [documentation](/docs/integrations/vectorstores/sap_hanavector.ipynb).\n", | ||
"\n", | ||
"We use the same setup here:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"\n", | ||
"# Use OPENAI_API_KEY env variable\n", | ||
"# os.environ[\"OPENAI_API_KEY\"] = \"Your OpenAI API key\"\n", | ||
"from hdbcli import dbapi\n", | ||
"\n", | ||
"# Use connection settings from the environment\n", | ||
"connection = dbapi.connect(\n", | ||
" address=os.environ.get(\"HANA_DB_ADDRESS\"),\n", | ||
" port=os.environ.get(\"HANA_DB_PORT\"),\n", | ||
" user=os.environ.get(\"HANA_DB_USER\"),\n", | ||
" password=os.environ.get(\"HANA_DB_PASSWORD\"),\n", | ||
" autocommit=True,\n", | ||
" sslValidateCertificate=False,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"To be able to self query with good performance we create additional metadata fields\n", | ||
"for our vectorstore table in HANA:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Create custom table with attribute\n", | ||
"cur = connection.cursor()\n", | ||
"cur.execute(\"DROP TABLE LANGCHAIN_DEMO_SELF_QUERY\", ignoreErrors=True)\n", | ||
"cur.execute(\n", | ||
" (\n", | ||
" \"\"\"CREATE TABLE \"LANGCHAIN_DEMO_SELF_QUERY\" (\n", | ||
" \"name\" NVARCHAR(100), \"is_active\" BOOLEAN, \"id\" INTEGER, \"height\" DOUBLE,\n", | ||
" \"VEC_TEXT\" NCLOB, \n", | ||
" \"VEC_META\" NCLOB, \n", | ||
" \"VEC_VECTOR\" REAL_VECTOR\n", | ||
" )\"\"\"\n", | ||
" )\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Let's add some documents." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_community.vectorstores.hanavector import HanaDB\n", | ||
"from langchain_core.documents import Document\n", | ||
"from langchain_openai import OpenAIEmbeddings\n", | ||
"\n", | ||
"embeddings = OpenAIEmbeddings()\n", | ||
"\n", | ||
"# Prepare some test documents\n", | ||
"docs = [\n", | ||
" Document(\n", | ||
" page_content=\"First\",\n", | ||
" metadata={\"name\": \"adam\", \"is_active\": True, \"id\": 1, \"height\": 10.0},\n", | ||
" ),\n", | ||
" Document(\n", | ||
" page_content=\"Second\",\n", | ||
" metadata={\"name\": \"bob\", \"is_active\": False, \"id\": 2, \"height\": 5.7},\n", | ||
" ),\n", | ||
" Document(\n", | ||
" page_content=\"Third\",\n", | ||
" metadata={\"name\": \"jane\", \"is_active\": True, \"id\": 3, \"height\": 2.4},\n", | ||
" ),\n", | ||
"]\n", | ||
"\n", | ||
"db = HanaDB(\n", | ||
" connection=connection,\n", | ||
" embedding=embeddings,\n", | ||
" table_name=\"LANGCHAIN_DEMO_SELF_QUERY\",\n", | ||
" specific_metadata_columns=[\"name\", \"is_active\", \"id\", \"height\"],\n", | ||
")\n", | ||
"\n", | ||
"# Delete already existing documents from the table\n", | ||
"db.delete(filter={})\n", | ||
"db.add_documents(docs)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Self querying\n", | ||
"\n", | ||
"Now for the main act: here is how to construct a SelfQueryRetriever for HANA vectorstore:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain.chains.query_constructor.base import AttributeInfo\n", | ||
"from langchain.retrievers.self_query.base import SelfQueryRetriever\n", | ||
"from langchain_community.query_constructors.hanavector import HanaTranslator\n", | ||
"from langchain_openai import ChatOpenAI\n", | ||
"\n", | ||
"llm = ChatOpenAI(model=\"gpt-3.5-turbo\")\n", | ||
"\n", | ||
"metadata_field_info = [\n", | ||
" AttributeInfo(\n", | ||
" name=\"name\",\n", | ||
" description=\"The name of the person\",\n", | ||
" type=\"string\",\n", | ||
" ),\n", | ||
" AttributeInfo(\n", | ||
" name=\"is_active\",\n", | ||
" description=\"Whether the person is active\",\n", | ||
" type=\"boolean\",\n", | ||
" ),\n", | ||
" AttributeInfo(\n", | ||
" name=\"id\",\n", | ||
" description=\"The ID of the person\",\n", | ||
" type=\"integer\",\n", | ||
" ),\n", | ||
" AttributeInfo(\n", | ||
" name=\"height\",\n", | ||
" description=\"The height of the person\",\n", | ||
" type=\"float\",\n", | ||
" ),\n", | ||
"]\n", | ||
"\n", | ||
"document_content_description = \"A collection of persons\"\n", | ||
"\n", | ||
"hana_translator = HanaTranslator()\n", | ||
"\n", | ||
"retriever = SelfQueryRetriever.from_llm(\n", | ||
" llm,\n", | ||
" db,\n", | ||
" document_content_description,\n", | ||
" metadata_field_info,\n", | ||
" structured_query_translator=hana_translator,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Let's use this retriever to prepare a (self) query for a person:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"query_prompt = \"Which person is not active?\"\n", | ||
"\n", | ||
"docs = retriever.invoke(input=query_prompt)\n", | ||
"for doc in docs:\n", | ||
" print(\"-\" * 80)\n", | ||
" print(doc.page_content, \" \", doc.metadata)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We can also take a look at how the query is being constructed:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain.chains.query_constructor.base import (\n", | ||
" StructuredQueryOutputParser,\n", | ||
" get_query_constructor_prompt,\n", | ||
")\n", | ||
"\n", | ||
"prompt = get_query_constructor_prompt(\n", | ||
" document_content_description,\n", | ||
" metadata_field_info,\n", | ||
")\n", | ||
"output_parser = StructuredQueryOutputParser.from_components()\n", | ||
"query_constructor = prompt | llm | output_parser\n", | ||
"\n", | ||
"sq = query_constructor.invoke(input=query_prompt)\n", | ||
"\n", | ||
"print(\"Structured query: \", sq)\n", | ||
"\n", | ||
"print(\"Translated for hana vector store: \", hana_translator.visit_structured_query(sq))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.14" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
57 changes: 57 additions & 0 deletions
57
libs/community/langchain_community/query_constructors/hanavector.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# HANA Translator/query constructor | ||
from typing import Dict, Tuple, Union | ||
|
||
from langchain_core.structured_query import ( | ||
Comparator, | ||
Comparison, | ||
Operation, | ||
Operator, | ||
StructuredQuery, | ||
Visitor, | ||
) | ||
|
||
|
||
class HanaTranslator(Visitor): | ||
""" | ||
Translate internal query language elements to valid filters params for | ||
HANA vectorstore. | ||
""" | ||
|
||
allowed_operators = [Operator.AND, Operator.OR] | ||
"""Subset of allowed logical operators.""" | ||
allowed_comparators = [ | ||
Comparator.EQ, | ||
Comparator.NE, | ||
Comparator.GT, | ||
Comparator.LT, | ||
Comparator.GTE, | ||
Comparator.LTE, | ||
Comparator.IN, | ||
Comparator.NIN, | ||
# Comparator.CONTAIN, | ||
Comparator.LIKE, | ||
] | ||
|
||
def _format_func(self, func: Union[Operator, Comparator]) -> str: | ||
self._validate_func(func) | ||
return f"${func.value}" | ||
|
||
def visit_operation(self, operation: Operation) -> Dict: | ||
args = [arg.accept(self) for arg in operation.arguments] | ||
return {self._format_func(operation.operator): args} | ||
|
||
def visit_comparison(self, comparison: Comparison) -> Dict: | ||
return { | ||
comparison.attribute: { | ||
self._format_func(comparison.comparator): comparison.value | ||
} | ||
} | ||
|
||
def visit_structured_query( | ||
self, structured_query: StructuredQuery | ||
) -> Tuple[str, dict]: | ||
if structured_query.filter is None: | ||
kwargs = {} | ||
else: | ||
kwargs = {"filter": structured_query.filter.accept(self)} | ||
return structured_query.query, kwargs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.