Skip to content

Commit

Permalink
community[minor]: Self query retriever for HANA Cloud Vector Engine (l…
Browse files Browse the repository at this point in the history
…angchain-ai#24494)

Description:

- This PR adds a self query retriever implementation for SAP HANA Cloud
Vector Engine. The retriever supports all operators except for contains.
- Issue: N/A
- Dependencies: no new dependencies added

**Add tests and docs:**
Added integration tests to:
libs/community/tests/unit_tests/query_constructors/test_hanavector.py

**Documentation for self query retriever:**
/docs/integrations/retrievers/self_query/hanavector_self_query.ipynb

---------

Co-authored-by: Bagatur <[email protected]>
Co-authored-by: Bagatur <[email protected]>
  • Loading branch information
3 people authored Jul 26, 2024
1 parent 4f3b4fc commit b65ac8d
Show file tree
Hide file tree
Showing 5 changed files with 408 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SAP HANA Cloud Vector Engine\n",
"\n",
"For more information on how to setup the SAP HANA vetor store, take a look at the [documentation](/docs/integrations/vectorstores/sap_hanavector.ipynb).\n",
"\n",
"We use the same setup here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Use OPENAI_API_KEY env variable\n",
"# os.environ[\"OPENAI_API_KEY\"] = \"Your OpenAI API key\"\n",
"from hdbcli import dbapi\n",
"\n",
"# Use connection settings from the environment\n",
"connection = dbapi.connect(\n",
" address=os.environ.get(\"HANA_DB_ADDRESS\"),\n",
" port=os.environ.get(\"HANA_DB_PORT\"),\n",
" user=os.environ.get(\"HANA_DB_USER\"),\n",
" password=os.environ.get(\"HANA_DB_PASSWORD\"),\n",
" autocommit=True,\n",
" sslValidateCertificate=False,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To be able to self query with good performance we create additional metadata fields\n",
"for our vectorstore table in HANA:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create custom table with attribute\n",
"cur = connection.cursor()\n",
"cur.execute(\"DROP TABLE LANGCHAIN_DEMO_SELF_QUERY\", ignoreErrors=True)\n",
"cur.execute(\n",
" (\n",
" \"\"\"CREATE TABLE \"LANGCHAIN_DEMO_SELF_QUERY\" (\n",
" \"name\" NVARCHAR(100), \"is_active\" BOOLEAN, \"id\" INTEGER, \"height\" DOUBLE,\n",
" \"VEC_TEXT\" NCLOB, \n",
" \"VEC_META\" NCLOB, \n",
" \"VEC_VECTOR\" REAL_VECTOR\n",
" )\"\"\"\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's add some documents."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.vectorstores.hanavector import HanaDB\n",
"from langchain_core.documents import Document\n",
"from langchain_openai import OpenAIEmbeddings\n",
"\n",
"embeddings = OpenAIEmbeddings()\n",
"\n",
"# Prepare some test documents\n",
"docs = [\n",
" Document(\n",
" page_content=\"First\",\n",
" metadata={\"name\": \"adam\", \"is_active\": True, \"id\": 1, \"height\": 10.0},\n",
" ),\n",
" Document(\n",
" page_content=\"Second\",\n",
" metadata={\"name\": \"bob\", \"is_active\": False, \"id\": 2, \"height\": 5.7},\n",
" ),\n",
" Document(\n",
" page_content=\"Third\",\n",
" metadata={\"name\": \"jane\", \"is_active\": True, \"id\": 3, \"height\": 2.4},\n",
" ),\n",
"]\n",
"\n",
"db = HanaDB(\n",
" connection=connection,\n",
" embedding=embeddings,\n",
" table_name=\"LANGCHAIN_DEMO_SELF_QUERY\",\n",
" specific_metadata_columns=[\"name\", \"is_active\", \"id\", \"height\"],\n",
")\n",
"\n",
"# Delete already existing documents from the table\n",
"db.delete(filter={})\n",
"db.add_documents(docs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Self querying\n",
"\n",
"Now for the main act: here is how to construct a SelfQueryRetriever for HANA vectorstore:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains.query_constructor.base import AttributeInfo\n",
"from langchain.retrievers.self_query.base import SelfQueryRetriever\n",
"from langchain_community.query_constructors.hanavector import HanaTranslator\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo\")\n",
"\n",
"metadata_field_info = [\n",
" AttributeInfo(\n",
" name=\"name\",\n",
" description=\"The name of the person\",\n",
" type=\"string\",\n",
" ),\n",
" AttributeInfo(\n",
" name=\"is_active\",\n",
" description=\"Whether the person is active\",\n",
" type=\"boolean\",\n",
" ),\n",
" AttributeInfo(\n",
" name=\"id\",\n",
" description=\"The ID of the person\",\n",
" type=\"integer\",\n",
" ),\n",
" AttributeInfo(\n",
" name=\"height\",\n",
" description=\"The height of the person\",\n",
" type=\"float\",\n",
" ),\n",
"]\n",
"\n",
"document_content_description = \"A collection of persons\"\n",
"\n",
"hana_translator = HanaTranslator()\n",
"\n",
"retriever = SelfQueryRetriever.from_llm(\n",
" llm,\n",
" db,\n",
" document_content_description,\n",
" metadata_field_info,\n",
" structured_query_translator=hana_translator,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's use this retriever to prepare a (self) query for a person:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query_prompt = \"Which person is not active?\"\n",
"\n",
"docs = retriever.invoke(input=query_prompt)\n",
"for doc in docs:\n",
" print(\"-\" * 80)\n",
" print(doc.page_content, \" \", doc.metadata)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also take a look at how the query is being constructed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains.query_constructor.base import (\n",
" StructuredQueryOutputParser,\n",
" get_query_constructor_prompt,\n",
")\n",
"\n",
"prompt = get_query_constructor_prompt(\n",
" document_content_description,\n",
" metadata_field_info,\n",
")\n",
"output_parser = StructuredQueryOutputParser.from_components()\n",
"query_constructor = prompt | llm | output_parser\n",
"\n",
"sq = query_constructor.invoke(input=query_prompt)\n",
"\n",
"print(\"Structured query: \", sq)\n",
"\n",
"print(\"Translated for hana vector store: \", hana_translator.visit_structured_query(sq))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# HANA Translator/query constructor
from typing import Dict, Tuple, Union

from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)


class HanaTranslator(Visitor):
"""
Translate internal query language elements to valid filters params for
HANA vectorstore.
"""

allowed_operators = [Operator.AND, Operator.OR]
"""Subset of allowed logical operators."""
allowed_comparators = [
Comparator.EQ,
Comparator.NE,
Comparator.GT,
Comparator.LT,
Comparator.GTE,
Comparator.LTE,
Comparator.IN,
Comparator.NIN,
# Comparator.CONTAIN,
Comparator.LIKE,
]

def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
return f"${func.value}"

def visit_operation(self, operation: Operation) -> Dict:
args = [arg.accept(self) for arg in operation.arguments]
return {self._format_func(operation.operator): args}

def visit_comparison(self, comparison: Comparison) -> Dict:
return {
comparison.attribute: {
self._format_func(comparison.comparator): comparison.value
}
}

def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs
13 changes: 11 additions & 2 deletions libs/community/langchain_community/vectorstores/hanavector.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ def _check_column( # type: ignore[no-untyped-def]
if column_length is not None and column_length > 0:
if rows[0][1] != column_length:
raise AttributeError(
f"Column {column_name} has the wrong length: {rows[0][1]}"
f"Column {column_name} has the wrong length: {rows[0][1]} "
f"expected: {column_length}"
)
else:
raise AttributeError(f"Column {column_name} does not exist")
Expand Down Expand Up @@ -529,10 +530,18 @@ def _process_filter_object(self, filter): # type: ignore[no-untyped-def]
if special_op in COMPARISONS_TO_SQL:
operator = COMPARISONS_TO_SQL[special_op]
if isinstance(special_val, bool):
query_tuple.append("true" if filter_value else "false")
query_tuple.append("true" if special_val else "false")
elif isinstance(special_val, float):
sql_param = "CAST(? as float)"
query_tuple.append(special_val)
elif (
isinstance(special_val, dict)
and "type" in special_val
and special_val["type"] == "date"
):
# Date type
sql_param = "CAST(? as DATE)"
query_tuple.append(special_val["date"])
else:
query_tuple.append(special_val)
# "$between"
Expand Down
Loading

0 comments on commit b65ac8d

Please sign in to comment.