feat: support opensearch approximate k-NN (langgenius#5322)
baojingyu authored Jun 19, 2024
1 parent dfada6d commit 1b68505
Showing 13 changed files with 524 additions and 4 deletions.
8 changes: 8 additions & 0 deletions api/commands.py
@@ -327,6 +327,14 @@ def migrate_knowledge_vector_database():
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == VectorType.OPENSEARCH:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": VectorType.OPENSEARCH,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {vector_type} is not supported.")

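This branch mirrors the other vector stores handled above: the migration records only a type tag and the collection-name prefix. A sketch of the resulting dataset.index_struct value (the collection name is illustrative; Dataset.gen_collection_name_by_id derives the real one from the dataset id):

import json

# Illustrative only: what the OPENSEARCH branch serializes. VectorType.OPENSEARCH
# is a str Enum, so json.dumps renders it as the plain string "opensearch".
example_struct = {
    "type": "opensearch",
    "vector_store": {"class_prefix": "Vector_index_example_Node"},
}
print(json.dumps(example_struct))
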
7 changes: 7 additions & 0 deletions api/config.py
@@ -282,6 +282,13 @@ def __init__(self):
self.MILVUS_SECURE = get_env('MILVUS_SECURE')
self.MILVUS_DATABASE = get_env('MILVUS_DATABASE')

# OpenSearch settings
self.OPENSEARCH_HOST = get_env('OPENSEARCH_HOST')
self.OPENSEARCH_PORT = get_env('OPENSEARCH_PORT')
self.OPENSEARCH_USER = get_env('OPENSEARCH_USER')
self.OPENSEARCH_PASSWORD = get_env('OPENSEARCH_PASSWORD')
self.OPENSEARCH_SECURE = get_bool_env('OPENSEARCH_SECURE')

# weaviate settings
self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')
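The five OPENSEARCH_* settings above are read from the process environment at startup. A minimal sketch of a matching configuration (values are illustrative; point them at your own cluster):

import os

# Illustrative values only.
os.environ['OPENSEARCH_HOST'] = 'opensearch'
os.environ['OPENSEARCH_PORT'] = '9200'
os.environ['OPENSEARCH_USER'] = 'admin'
os.environ['OPENSEARCH_PASSWORD'] = 'admin'
os.environ['OPENSEARCH_SECURE'] = 'true'  # parsed by get_bool_env
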
4 changes: 2 additions & 2 deletions api/controllers/console/datasets/datasets.py
@@ -503,7 +503,7 @@ def get(self):
'semantic_search'
]
}
case VectorType.QDRANT | VectorType.WEAVIATE:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH:
return {
'retrieval_method': [
'semantic_search', 'full_text_search', 'hybrid_search'
@@ -525,7 +525,7 @@ def get(self, vector_type):
'semantic_search'
]
}
case VectorType.QDRANT | VectorType.WEAVIATE:
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH:
return {
'retrieval_method': [
'semantic_search', 'full_text_search', 'hybrid_search'
Empty file.
278 changes: 278 additions & 0 deletions api/core/rag/datasource/vdb/opensearch/opensearch_vector.py
@@ -0,0 +1,278 @@
import json
import logging
import ssl
from typing import Any, Optional
from uuid import uuid4

from flask import current_app
from opensearchpy import OpenSearch, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator

from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

logger = logging.getLogger(__name__)


class OpenSearchConfig(BaseModel):
host: str
port: int
user: Optional[str] = None
password: Optional[str] = None
secure: bool = False

@model_validator(mode='before')
def validate_config(cls, values: dict) -> dict:
if not values.get('host'):
raise ValueError("config OPENSEARCH_HOST is required")
if not values.get('port'):
raise ValueError("config OPENSEARCH_PORT is required")
return values

def create_ssl_context(self) -> ssl.SSLContext:
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE # Disable Certificate Validation
return ssl_context

def to_opensearch_params(self) -> dict[str, Any]:
params = {
'hosts': [{'host': self.host, 'port': self.port}],
'use_ssl': self.secure,
'verify_certs': self.secure,
}
if self.user and self.password:
params['http_auth'] = (self.user, self.password)
if self.secure:
params['ssl_context'] = self.create_ssl_context()
return params
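    # Usage sketch (illustrative): with secure=True, TLS is enabled but
    # certificate verification is disabled by the custom SSL context above,
    # which suits self-signed development clusters.
    #
    #     config = OpenSearchConfig(host='localhost', port=9200,
    #                               user='admin', password='admin', secure=True)
    #     client = OpenSearch(**config.to_opensearch_params())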


class OpenSearchVector(BaseVector):

def __init__(self, collection_name: str, config: OpenSearchConfig):
super().__init__(collection_name)
self._client_config = config
self._client = OpenSearch(**config.to_opensearch_params())

def get_type(self) -> str:
return VectorType.OPENSEARCH

def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
metadatas = [d.metadata for d in texts]
self.create_collection(embeddings, metadatas)
self.add_texts(texts, embeddings)

def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
actions = []
for i in range(len(documents)):
action = {
"_op_type": "index",
"_index": self._collection_name.lower(),
"_id": uuid4().hex,
"_source": {
Field.CONTENT_KEY.value: documents[i].page_content,
                    Field.VECTOR.value: embeddings[i],  # the embedding must be a list of floats
Field.METADATA_KEY.value: documents[i].metadata,
}
}
actions.append(action)

helpers.bulk(self._client, actions)

def delete_by_document_id(self, document_id: str):
ids = self.get_ids_by_metadata_field('document_id', document_id)
if ids:
self.delete_by_ids(ids)

def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
response = self._client.search(index=self._collection_name.lower(), body=query)
if response['hits']['hits']:
return [hit['_id'] for hit in response['hits']['hits']]
else:
return None

def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self.delete_by_ids(ids)

def delete_by_ids(self, ids: list[str]) -> None:
index_name = self._collection_name.lower()
if not self._client.indices.exists(index=index_name):
logger.warning(f"Index {index_name} does not exist")
return

        # Resolve the supplied doc_ids to the actual OpenSearch document IDs
actual_ids = []

for doc_id in ids:
es_ids = self.get_ids_by_metadata_field('doc_id', doc_id)
if es_ids:
actual_ids.extend(es_ids)
else:
logger.warning(f"Document with metadata doc_id {doc_id} not found for deletion")

if actual_ids:
actions = [{"_op_type": "delete", "_index": index_name, "_id": es_id} for es_id in actual_ids]
try:
helpers.bulk(self._client, actions)
except BulkIndexError as e:
for error in e.errors:
delete_error = error.get('delete', {})
status = delete_error.get('status')
doc_id = delete_error.get('_id')

if status == 404:
logger.warning(f"Document not found for deletion: {doc_id}")
else:
logger.error(f"Error deleting document: {error}")

def delete(self) -> None:
self._client.indices.delete(index=self._collection_name.lower())

def text_exists(self, id: str) -> bool:
try:
self._client.get(index=self._collection_name.lower(), id=id)
return True
        except Exception:
return False

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Make sure query_vector is a list
if not isinstance(query_vector, list):
raise ValueError("query_vector should be a list of floats")

        # Validate that every element of query_vector is a float
if not all(isinstance(x, float) for x in query_vector):
raise ValueError("All elements in query_vector should be floats")

query = {
"size": kwargs.get('top_k', 4),
"query": {
"knn": {
Field.VECTOR.value: {
                        Field.VECTOR.value: query_vector,  # OpenSearch expects the literal key "vector" here; Field.VECTOR.value carries that value
"k": kwargs.get('top_k', 4)
}
}
}
}

try:
response = self._client.search(index=self._collection_name.lower(), body=query)
except Exception as e:
logger.error(f"Error executing search: {e}")
raise

docs = []
for hit in response['hits']['hits']:
metadata = hit['_source'].get(Field.METADATA_KEY.value, {})

# Make sure metadata is a dictionary
if metadata is None:
metadata = {}

metadata['score'] = hit['_score']
            score_threshold = kwargs.get('score_threshold') or 0.0
if hit['_score'] > score_threshold:
doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)

return docs
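    # Call sketch (illustrative): search_by_vector issues an approximate k-NN
    # query against the knn_vector field; 'top_k' drives both 'k' and 'size',
    # and hits whose score does not exceed 'score_threshold' are dropped.
    #
    #     docs = vector.search_by_vector(query_embedding, top_k=4,
    #                                    score_threshold=0.5)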

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}}

response = self._client.search(index=self._collection_name.lower(), body=full_text_query)

docs = []
for hit in response['hits']['hits']:
metadata = hit['_source'].get(Field.METADATA_KEY.value)
doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)

return docs

def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
lock_name = f'vector_indexing_lock_{self._collection_name.lower()}'
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f'vector_indexing_{self._collection_name.lower()}'
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name.lower()} already exists.")
return

if not self._client.indices.exists(index=self._collection_name.lower()):
index_body = {
"settings": {
"index": {
"knn": True
}
},
"mappings": {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
                            Field.VECTOR.value: {
                                "type": "knn_vector",
                                "dimension": len(embeddings[0]),  # dimension is fixed by the first embedding
                                "method": {
                                    "name": "hnsw",
                                    "space_type": "l2",
                                    "engine": "faiss",
                                    # HNSW build-time knobs: ef_construction sizes the
                                    # candidate list during graph construction; m caps
                                    # each node's neighbor count. Larger values raise
                                    # recall at the cost of memory and indexing time.
                                    "parameters": {
                                        "ef_construction": 64,
                                        "m": 8
                                    }
}
},
                            Field.METADATA_KEY.value: {
                                "type": "object",
                                "properties": {
                                    "doc_id": {"type": "keyword"}  # keyword type enables the exact-match term queries used for lookups and deletes
}
}
}
}
}

self._client.indices.create(index=self._collection_name.lower(), body=index_body)

redis_client.set(collection_exist_cache_key, 1, ex=3600)


class OpenSearchVectorFactory(AbstractVectorFactory):

def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenSearchVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))

config = current_app.config

open_search_config = OpenSearchConfig(
host=config.get('OPENSEARCH_HOST'),
port=config.get('OPENSEARCH_PORT'),
user=config.get('OPENSEARCH_USER'),
password=config.get('OPENSEARCH_PASSWORD'),
secure=config.get('OPENSEARCH_SECURE'),
)

return OpenSearchVector(
collection_name=collection_name,
config=open_search_config
)
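Taken together, the factory resolves a collection name from the dataset's index_struct (or generates one) and builds the client from app config. A minimal standalone sketch of the resulting objects (values are illustrative; in Dify this path is reached through the vector factory, and create() also takes the Redis lock shown in create_collection, so Redis must be reachable):

from core.rag.datasource.vdb.opensearch.opensearch_vector import (
    OpenSearchConfig,
    OpenSearchVector,
)
from core.rag.models.document import Document

# Illustrative values; a real deployment reads these from app config.
config = OpenSearchConfig(host='localhost', port=9200, secure=False)
vector = OpenSearchVector(collection_name='vector_index_example_node', config=config)

docs = [Document(page_content='hello opensearch',
                 metadata={'doc_id': 'd1', 'document_id': 'doc-1'})]
embedding = [0.1, 0.2, 0.3, 0.4]  # its length fixes the index's vector dimension
vector.create(docs, [embedding])

results = vector.search_by_vector(embedding, top_k=1)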
3 changes: 3 additions & 0 deletions api/core/rag/datasource/vdb/vector_factory.py
@@ -78,6 +78,9 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
case VectorType.TENCENT:
from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
return TencentVectorFactory
case VectorType.OPENSEARCH:
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
return OpenSearchVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")

1 change: 1 addition & 0 deletions api/core/rag/datasource/vdb/vector_type.py
@@ -10,4 +10,5 @@ class VectorType(str, Enum):
RELYT = 'relyt'
TIDB_VECTOR = 'tidb_vector'
WEAVIATE = 'weaviate'
OPENSEARCH = 'opensearch'
TENCENT = 'tencent'
Empty file added api/events/__init__.py
