From 1b685050668f619662b1c5f6b4799aa6df59d619 Mon Sep 17 00:00:00 2001 From: baojingyu <31037754+baojingyu@users.noreply.github.com> Date: Wed, 19 Jun 2024 12:44:33 +0800 Subject: [PATCH] feat: support opensearch approximate k-NN (#5322) --- api/commands.py | 8 + api/config.py | 7 + api/controllers/console/datasets/datasets.py | 4 +- .../rag/datasource/vdb/opensearch/__init__.py | 0 .../vdb/opensearch/opensearch_vector.py | 278 ++++++++++++++++++ api/core/rag/datasource/vdb/vector_factory.py | 3 + api/core/rag/datasource/vdb/vector_type.py | 1 + api/events/__init__.py | 0 api/poetry.lock | 37 ++- api/pyproject.toml | 1 + api/requirements.txt | 3 +- .../vdb/opensearch/__init__.py | 0 .../vdb/opensearch/test_opensearch.py | 186 ++++++++++++ 13 files changed, 524 insertions(+), 4 deletions(-) create mode 100644 api/core/rag/datasource/vdb/opensearch/__init__.py create mode 100644 api/core/rag/datasource/vdb/opensearch/opensearch_vector.py create mode 100644 api/events/__init__.py create mode 100644 api/tests/integration_tests/vdb/opensearch/__init__.py create mode 100644 api/tests/integration_tests/vdb/opensearch/test_opensearch.py diff --git a/api/commands.py b/api/commands.py index f3e0769134f137..91d77370236322 100644 --- a/api/commands.py +++ b/api/commands.py @@ -327,6 +327,14 @@ def migrate_knowledge_vector_database(): "vector_store": {"class_prefix": collection_name} } dataset.index_struct = json.dumps(index_struct_dict) + elif vector_type == VectorType.OPENSEARCH: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + index_struct_dict = { + "type": VectorType.OPENSEARCH, + "vector_store": {"class_prefix": collection_name} + } + dataset.index_struct = json.dumps(index_struct_dict) else: raise ValueError(f"Vector store {vector_type} is not supported.") diff --git a/api/config.py b/api/config.py index 3c62501f2eaeaf..cca05922c1c5c1 100644 --- a/api/config.py +++ b/api/config.py @@ -282,6 +282,13 @@ def __init__(self): self.MILVUS_SECURE = get_env('MILVUS_SECURE') self.MILVUS_DATABASE = get_env('MILVUS_DATABASE') + # OpenSearch settings + self.OPENSEARCH_HOST = get_env('OPENSEARCH_HOST') + self.OPENSEARCH_PORT = get_env('OPENSEARCH_PORT') + self.OPENSEARCH_USER = get_env('OPENSEARCH_USER') + self.OPENSEARCH_PASSWORD = get_env('OPENSEARCH_PASSWORD') + self.OPENSEARCH_SECURE = get_bool_env('OPENSEARCH_SECURE') + # weaviate settings self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT') self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY') diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index cb14abe9231d16..de99b89ef674aa 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -503,7 +503,7 @@ def get(self): 'semantic_search' ] } - case VectorType.QDRANT | VectorType.WEAVIATE: + case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH: return { 'retrieval_method': [ 'semantic_search', 'full_text_search', 'hybrid_search' @@ -525,7 +525,7 @@ def get(self, vector_type): 'semantic_search' ] } - case VectorType.QDRANT | VectorType.WEAVIATE: + case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH: return { 'retrieval_method': [ 'semantic_search', 'full_text_search', 'hybrid_search' diff --git a/api/core/rag/datasource/vdb/opensearch/__init__.py b/api/core/rag/datasource/vdb/opensearch/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py new file mode 100644 index 00000000000000..52f8b41bae0b23 --- /dev/null +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -0,0 +1,278 @@ +import json +import logging +import ssl +from typing import Any, Optional +from uuid import uuid4 + +from flask import current_app +from opensearchpy import OpenSearch, helpers +from opensearchpy.helpers import BulkIndexError +from pydantic import BaseModel, model_validator + +from core.rag.datasource.entity.embedding import Embeddings +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.models.document import Document +from extensions.ext_redis import redis_client +from models.dataset import Dataset + +logger = logging.getLogger(__name__) + + +class OpenSearchConfig(BaseModel): + host: str + port: int + user: Optional[str] = None + password: Optional[str] = None + secure: bool = False + + @model_validator(mode='before') + def validate_config(cls, values: dict) -> dict: + if not values.get('host'): + raise ValueError("config OPENSEARCH_HOST is required") + if not values.get('port'): + raise ValueError("config OPENSEARCH_PORT is required") + return values + + def create_ssl_context(self) -> ssl.SSLContext: + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE # Disable Certificate Validation + return ssl_context + + def to_opensearch_params(self) -> dict[str, Any]: + params = { + 'hosts': [{'host': self.host, 'port': self.port}], + 'use_ssl': self.secure, + 'verify_certs': self.secure, + } + if self.user and self.password: + params['http_auth'] = (self.user, self.password) + if self.secure: + params['ssl_context'] = self.create_ssl_context() + return params + + +class OpenSearchVector(BaseVector): + + def __init__(self, collection_name: str, config: OpenSearchConfig): + super().__init__(collection_name) + self._client_config = config + self._client = OpenSearch(**config.to_opensearch_params()) + + def get_type(self) -> str: + return VectorType.OPENSEARCH + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + metadatas = [d.metadata for d in texts] + self.create_collection(embeddings, metadatas) + self.add_texts(texts, embeddings) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + actions = [] + for i in range(len(documents)): + action = { + "_op_type": "index", + "_index": self._collection_name.lower(), + "_id": uuid4().hex, + "_source": { + Field.CONTENT_KEY.value: documents[i].page_content, + Field.VECTOR.value: embeddings[i], # Make sure you pass an array here + Field.METADATA_KEY.value: documents[i].metadata, + } + } + actions.append(action) + + helpers.bulk(self._client, actions) + + def delete_by_document_id(self, document_id: str): + ids = self.get_ids_by_metadata_field('document_id', document_id) + if ids: + self.delete_by_ids(ids) + + def get_ids_by_metadata_field(self, key: str, value: str): + query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}} + response = self._client.search(index=self._collection_name.lower(), body=query) + if response['hits']['hits']: + return [hit['_id'] for hit in response['hits']['hits']] + else: + return None + + def delete_by_metadata_field(self, key: str, value: str): + ids = self.get_ids_by_metadata_field(key, value) + if ids: + self.delete_by_ids(ids) + + def delete_by_ids(self, ids: list[str]) -> None: + index_name = self._collection_name.lower() + if not self._client.indices.exists(index=index_name): + logger.warning(f"Index {index_name} does not exist") + return + + # Obtaining All Actual Documents_ID + actual_ids = [] + + for doc_id in ids: + es_ids = self.get_ids_by_metadata_field('doc_id', doc_id) + if es_ids: + actual_ids.extend(es_ids) + else: + logger.warning(f"Document with metadata doc_id {doc_id} not found for deletion") + + if actual_ids: + actions = [{"_op_type": "delete", "_index": index_name, "_id": es_id} for es_id in actual_ids] + try: + helpers.bulk(self._client, actions) + except BulkIndexError as e: + for error in e.errors: + delete_error = error.get('delete', {}) + status = delete_error.get('status') + doc_id = delete_error.get('_id') + + if status == 404: + logger.warning(f"Document not found for deletion: {doc_id}") + else: + logger.error(f"Error deleting document: {error}") + + def delete(self) -> None: + self._client.indices.delete(index=self._collection_name.lower()) + + def text_exists(self, id: str) -> bool: + try: + self._client.get(index=self._collection_name.lower(), id=id) + return True + except: + return False + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + # Make sure query_vector is a list + if not isinstance(query_vector, list): + raise ValueError("query_vector should be a list of floats") + + # Check whether query_vector is a floating-point number list + if not all(isinstance(x, float) for x in query_vector): + raise ValueError("All elements in query_vector should be floats") + + query = { + "size": kwargs.get('top_k', 4), + "query": { + "knn": { + Field.VECTOR.value: { + Field.VECTOR.value: query_vector, + "k": kwargs.get('top_k', 4) + } + } + } + } + + try: + response = self._client.search(index=self._collection_name.lower(), body=query) + except Exception as e: + logger.error(f"Error executing search: {e}") + raise + + docs = [] + for hit in response['hits']['hits']: + metadata = hit['_source'].get(Field.METADATA_KEY.value, {}) + + # Make sure metadata is a dictionary + if metadata is None: + metadata = {} + + metadata['score'] = hit['_score'] + score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 + if hit['_score'] > score_threshold: + doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata) + docs.append(doc) + + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}} + + response = self._client.search(index=self._collection_name.lower(), body=full_text_query) + + docs = [] + for hit in response['hits']['hits']: + metadata = hit['_source'].get(Field.METADATA_KEY.value) + doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata) + docs.append(doc) + + return docs + + def create_collection( + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + ): + lock_name = f'vector_indexing_lock_{self._collection_name.lower()}' + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = f'vector_indexing_{self._collection_name.lower()}' + if redis_client.get(collection_exist_cache_key): + logger.info(f"Collection {self._collection_name.lower()} already exists.") + return + + if not self._client.indices.exists(index=self._collection_name.lower()): + index_body = { + "settings": { + "index": { + "knn": True + } + }, + "mappings": { + "properties": { + Field.CONTENT_KEY.value: {"type": "text"}, + Field.VECTOR.value: { + "type": "knn_vector", + "dimension": len(embeddings[0]), # Make sure the dimension is correct here + "method": { + "name": "hnsw", + "space_type": "l2", + "engine": "faiss", + "parameters": { + "ef_construction": 64, + "m": 8 + } + } + }, + Field.METADATA_KEY.value: { + "type": "object", + "properties": { + "doc_id": {"type": "keyword"} # Map doc_id to keyword type + } + } + } + } + } + + self._client.indices.create(index=self._collection_name.lower(), body=index_body) + + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + +class OpenSearchVectorFactory(AbstractVectorFactory): + + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OpenSearchVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + collection_name = class_prefix.lower() + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() + dataset.index_struct = json.dumps( + self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name)) + + config = current_app.config + + open_search_config = OpenSearchConfig( + host=config.get('OPENSEARCH_HOST'), + port=config.get('OPENSEARCH_PORT'), + user=config.get('OPENSEARCH_USER'), + password=config.get('OPENSEARCH_PASSWORD'), + secure=config.get('OPENSEARCH_SECURE'), + ) + + return OpenSearchVector( + collection_name=collection_name, + config=open_search_config + ) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 48f18df31f0be2..8882cb2170b3a8 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -78,6 +78,9 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]: case VectorType.TENCENT: from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory return TencentVectorFactory + case VectorType.OPENSEARCH: + from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory + return OpenSearchVectorFactory case _: raise ValueError(f"Vector store {vector_type} is not supported.") diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index aba4f757507329..4a27e52706d369 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -10,4 +10,5 @@ class VectorType(str, Enum): RELYT = 'relyt' TIDB_VECTOR = 'tidb_vector' WEAVIATE = 'weaviate' + OPENSEARCH = 'opensearch' TENCENT = 'tencent' diff --git a/api/events/__init__.py b/api/events/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/poetry.lock b/api/poetry.lock index fe8e1ebb28e8a7..2dfed17be2261e 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -4891,6 +4891,30 @@ files = [ [package.dependencies] et-xmlfile = "*" +[[package]] +name = "opensearch-py" +version = "2.4.0" +description = "Python client for OpenSearch" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, <4" +files = [ + {file = "opensearch-py-2.4.0.tar.gz", hash = "sha256:7eba2b6ed2ddcf33225bfebfba2aee026877838cc39f760ec80f27827308cc4b"}, + {file = "opensearch_py-2.4.0-py2.py3-none-any.whl", hash = "sha256:316077235437c8ceac970232261f3393c65fb92a80f33c5b106f50f1dab24fd9"}, +] + +[package.dependencies] +certifi = ">=2022.12.07" +python-dateutil = "*" +requests = ">=2.4.0,<3.0.0" +six = "*" +urllib3 = ">=1.26.18" + +[package.extras] +async = ["aiohttp (>=3,<4)"] +develop = ["black", "botocore", "coverage (<8.0.0)", "jinja2", "mock", "myst-parser", "pytest (>=3.0.0)", "pytest-cov", "pytest-mock (<4.0.0)", "pytz", "pyyaml", "requests (>=2.0.0,<3.0.0)", "sphinx", "sphinx-copybutton", "sphinx-rtd-theme"] +docs = ["aiohttp (>=3,<4)", "myst-parser", "sphinx", "sphinx-copybutton", "sphinx-rtd-theme"] +kerberos = ["requests-kerberos"] + [[package]] name = "opentelemetry-api" version = "1.25.0" @@ -6414,6 +6438,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -6421,8 +6446,16 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -6439,6 +6472,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -6446,6 +6480,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -8944,4 +8979,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "21360e271c46e0368b8e3bd26287caca73145a73ee73287669f91e7eac6f05b9" +content-hash = "367a4b0ad745a48263dd44711be28c4c076dee983e3f5d1ac56c22bbb2eed531" diff --git a/api/pyproject.toml b/api/pyproject.toml index a83d98b43842df..e4d10b7de9808d 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -185,6 +185,7 @@ chromadb = "~0.5.1" tenacity = "~8.3.0" cos-python-sdk-v5 = "1.9.30" novita-client = "^0.5.6" +opensearch-py = "2.4.0" [tool.poetry.group.dev] optional = true diff --git a/api/requirements.txt b/api/requirements.txt index 7ab636e226e604..1b618cb3f44ca4 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -90,4 +90,5 @@ tencentcloud-sdk-python-hunyuan~=3.0.1158 chromadb~=0.5.1 novita_client~=0.5.6 tenacity~=8.3.0 -cos-python-sdk-v5==1.9.30 +opensearch-py==2.4.0 +cos-python-sdk-v5==1.9.30 \ No newline at end of file diff --git a/api/tests/integration_tests/vdb/opensearch/__init__.py b/api/tests/integration_tests/vdb/opensearch/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py new file mode 100644 index 00000000000000..e372c9b7ac7d28 --- /dev/null +++ b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py @@ -0,0 +1,186 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchConfig, OpenSearchVector +from core.rag.models.document import Document +from extensions import ext_redis + + +def get_example_text() -> str: + return "This is a sample text for testing purposes." + + +@pytest.fixture(scope="module") +def setup_mock_redis(): + ext_redis.redis_client.get = MagicMock(return_value=None) + ext_redis.redis_client.set = MagicMock(return_value=None) + + mock_redis_lock = MagicMock() + mock_redis_lock.__enter__ = MagicMock() + mock_redis_lock.__exit__ = MagicMock() + ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock) + + +class TestOpenSearchVector: + def setup_method(self): + self.collection_name = "test_collection" + self.example_doc_id = "example_doc_id" + self.vector = OpenSearchVector( + collection_name=self.collection_name, + config=OpenSearchConfig( + host='localhost', + port=9200, + user='admin', + password='password', + secure=False + ) + ) + self.vector._client = MagicMock() + + @pytest.mark.parametrize("search_response, expected_length, expected_doc_id", [ + ({ + 'hits': { + 'total': {'value': 1}, + 'hits': [ + {'_source': {'page_content': get_example_text(), 'metadata': {"document_id": "example_doc_id"}}} + ] + } + }, 1, "example_doc_id"), + ({ + 'hits': { + 'total': {'value': 0}, + 'hits': [] + } + }, 0, None) + ]) + def test_search_by_full_text(self, search_response, expected_length, expected_doc_id): + self.vector._client.search.return_value = search_response + + hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == expected_length + if expected_length > 0: + assert hits_by_full_text[0].metadata['document_id'] == expected_doc_id + + def test_search_by_vector(self): + vector = [0.1] * 128 + mock_response = { + 'hits': { + 'total': {'value': 1}, + 'hits': [ + { + '_source': { + Field.CONTENT_KEY.value: get_example_text(), + Field.METADATA_KEY.value: {"document_id": self.example_doc_id} + }, + '_score': 1.0 + } + ] + } + } + self.vector._client.search.return_value = mock_response + + hits_by_vector = self.vector.search_by_vector(query_vector=vector) + + print("Hits by vector:", hits_by_vector) + print("Expected document ID:", self.example_doc_id) + print("Actual document ID:", hits_by_vector[0].metadata['document_id'] if hits_by_vector else "No hits") + + assert len(hits_by_vector) > 0, f"Expected at least one hit, got {len(hits_by_vector)}" + assert hits_by_vector[0].metadata['document_id'] == self.example_doc_id, \ + f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}" + + def test_delete_by_document_id(self): + self.vector._client.delete_by_query.return_value = {'deleted': 1} + + doc = Document(page_content="Test content to delete", metadata={"document_id": self.example_doc_id}) + embedding = [0.1] * 128 + + with patch('opensearchpy.helpers.bulk') as mock_bulk: + mock_bulk.return_value = ([], []) + self.vector.add_texts([doc], [embedding]) + + self.vector.delete_by_document_id(document_id=self.example_doc_id) + + self.vector._client.search.return_value = {'hits': {'total': {'value': 0}, 'hits': []}} + + ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + assert ids is None or len(ids) == 0 + + def test_get_ids_by_metadata_field(self): + mock_response = { + 'hits': { + 'total': {'value': 1}, + 'hits': [{'_id': 'mock_id'}] + } + } + self.vector._client.search.return_value = mock_response + + doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id}) + embedding = [0.1] * 128 + + with patch('opensearchpy.helpers.bulk') as mock_bulk: + mock_bulk.return_value = ([], []) + self.vector.add_texts([doc], [embedding]) + + ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + assert len(ids) == 1 + assert ids[0] == 'mock_id' + + def test_add_texts(self): + self.vector._client.index.return_value = {'result': 'created'} + + doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id}) + embedding = [0.1] * 128 + + with patch('opensearchpy.helpers.bulk') as mock_bulk: + mock_bulk.return_value = ([], []) + self.vector.add_texts([doc], [embedding]) + + mock_response = { + 'hits': { + 'total': {'value': 1}, + 'hits': [{'_id': 'mock_id'}] + } + } + self.vector._client.search.return_value = mock_response + + ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id) + assert len(ids) == 1 + assert ids[0] == 'mock_id' + +@pytest.mark.usefixtures("setup_mock_redis") +class TestOpenSearchVectorWithRedis: + def setup_method(self): + self.tester = TestOpenSearchVector() + + def test_search_by_full_text(self): + self.tester.setup_method() + search_response = { + 'hits': { + 'total': {'value': 1}, + 'hits': [ + {'_source': {'page_content': get_example_text(), 'metadata': {"document_id": "example_doc_id"}}} + ] + } + } + expected_length = 1 + expected_doc_id = "example_doc_id" + self.tester.test_search_by_full_text(search_response, expected_length, expected_doc_id) + + def test_delete_by_document_id(self): + self.tester.setup_method() + self.tester.test_delete_by_document_id() + + def test_get_ids_by_metadata_field(self): + self.tester.setup_method() + self.tester.test_get_ids_by_metadata_field() + + def test_add_texts(self): + self.tester.setup_method() + self.tester.test_add_texts() + + def test_search_by_vector(self): + self.tester.setup_method() + self.tester.test_search_by_vector()