From 3c735a373169a7a3a5ab41f0801a94207b1ecab9 Mon Sep 17 00:00:00 2001
From: Kenn
Date: Tue, 27 Aug 2024 11:43:44 +0800
Subject: [PATCH] feat: rewrite Elasticsearch index and search code to achieve
 Elasticsearch vector and full-text search (#7641)

Co-authored-by: haokai
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Bowen Liang
Co-authored-by: wellCh4n
---
 api/configs/middleware/__init__.py            |   2 +
 .../middleware/vdb/elasticsearch_config.py    |  30 ++++
 .../vdb/elasticsearch/elasticsearch_vector.py | 131 +++++++++++-------
 3 files changed, 111 insertions(+), 52 deletions(-)
 create mode 100644 api/configs/middleware/vdb/elasticsearch_config.py

diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index 05e9b8f7a6eede..f25979e5d8f775 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -13,6 +13,7 @@
 from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
 from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
 from configs.middleware.vdb.chroma_config import ChromaConfig
+from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
 from configs.middleware.vdb.milvus_config import MilvusConfig
 from configs.middleware.vdb.myscale_config import MyScaleConfig
 from configs.middleware.vdb.opensearch_config import OpenSearchConfig
@@ -200,5 +201,6 @@ class MiddlewareConfig(
     TencentVectorDBConfig,
     TiDBVectorConfig,
     WeaviateConfig,
+    ElasticsearchConfig,
 ):
     pass

diff --git a/api/configs/middleware/vdb/elasticsearch_config.py b/api/configs/middleware/vdb/elasticsearch_config.py
new file mode 100644
index 00000000000000..5b6a8fd939c292
--- /dev/null
+++ b/api/configs/middleware/vdb/elasticsearch_config.py
@@ -0,0 +1,30 @@
+from typing import Optional
+
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings
+
+
+class ElasticsearchConfig(BaseSettings):
+    """
+    Elasticsearch configs
+    """
+
+    ELASTICSEARCH_HOST: Optional[str] = Field(
+        description="Elasticsearch host",
+        default="127.0.0.1",
+    )
+
+    ELASTICSEARCH_PORT: PositiveInt = Field(
+        description="Elasticsearch port",
+        default=9200,
+    )
+
+    ELASTICSEARCH_USERNAME: Optional[str] = Field(
+        description="Elasticsearch username",
+        default="elastic",
+    )
+
+    ELASTICSEARCH_PASSWORD: Optional[str] = Field(
+        description="Elasticsearch password",
+        default="elastic",
+    )

diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
index 01ba6fb3248786..233539756fa84f 100644
--- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
+++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
@@ -1,5 +1,7 @@
 import json
-from typing import Any
+import logging
+from typing import Any, Optional
+from urllib.parse import urlparse
 
 import requests
 from elasticsearch import Elasticsearch
@@ -7,16 +9,20 @@
 from pydantic import BaseModel, model_validator
 
 from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
 from models.dataset import Dataset
 
+logger = logging.getLogger(__name__)
+
 
 class ElasticSearchConfig(BaseModel):
     host: str
-    port: str
+    port: int
     username: str
     password: str
@@ -37,12 +43,19 @@ class ElasticSearchVector(BaseVector):
     def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list):
         super().__init__(index_name.lower())
         self._client = self._init_client(config)
+        self._version = self._get_version()
+        self._check_version()
         self._attributes = attributes
 
     def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
         try:
+            parsed_url = urlparse(config.host)
+            if parsed_url.scheme in ['http', 'https']:
+                hosts = f'{config.host}:{config.port}'
+            else:
+                hosts = f'http://{config.host}:{config.port}'
             client = Elasticsearch(
-                hosts=f'{config.host}:{config.port}',
+                hosts=hosts,
                 basic_auth=(config.username, config.password),
                 request_timeout=100000,
                 retry_on_timeout=True,
@@ -53,42 +66,27 @@ def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
 
         return client
 
+    def _get_version(self) -> str:
+        info = self._client.info()
+        return info['version']['number']
+
+    def _check_version(self):
+        if self._version < '8.0.0':
+            raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
+
     def get_type(self) -> str:
         return 'elasticsearch'
 
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         uuids = self._get_uuids(documents)
-        texts = [d.page_content for d in documents]
-        metadatas = [d.metadata for d in documents]
-
-        if not self._client.indices.exists(index=self._collection_name):
-            dim = len(embeddings[0])
-            mapping = {
-                "properties": {
-                    "text": {
-                        "type": "text"
-                    },
-                    "vector": {
-                        "type": "dense_vector",
-                        "index": True,
-                        "dims": dim,
-                        "similarity": "l2_norm"
-                    },
-                }
-            }
-            self._client.indices.create(index=self._collection_name, mappings=mapping)
-
-        added_ids = []
-        for i, text in enumerate(texts):
+        for i in range(len(documents)):
             self._client.index(index=self._collection_name,
                                id=uuids[i],
                                document={
-                                   "text": text,
-                                   "vector": embeddings[i] if embeddings[i] else None,
-                                   "metadata": metadatas[i] if metadatas[i] else {},
+                                   Field.CONTENT_KEY.value: documents[i].page_content,
+                                   Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
+                                   Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {}
                                })
-            added_ids.append(uuids[i])
-
         self._client.indices.refresh(index=self._collection_name)
         return uuids
@@ -116,28 +114,21 @@ def delete(self) -> None:
         self._client.indices.delete(index=self._collection_name)
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        query_str = {
-            "query": {
-                "script_score": {
-                    "query": {
-                        "match_all": {}
-                    },
-                    "script": {
-                        "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
-                        "params": {
-                            "query_vector": query_vector
-                        }
-                    }
-                }
-            }
+        top_k = kwargs.get("top_k", 10)
+        knn = {
+            "field": Field.VECTOR.value,
+            "query_vector": query_vector,
+            "k": top_k
         }
 
-        results = self._client.search(index=self._collection_name, body=query_str)
+        results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
 
         docs_and_scores = []
         for hit in results['hits']['hits']:
             docs_and_scores.append(
-                (Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']), hit['_score']))
+                (Document(page_content=hit['_source'][Field.CONTENT_KEY.value],
+                          vector=hit['_source'][Field.VECTOR.value],
+                          metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score']))
 
         docs = []
         for doc, score in docs_and_scores:
@@ -146,25 +137,61 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
             doc.metadata['score'] = score
             docs.append(doc)
 
-        # Sort the documents by score in descending order
-        docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
-        return docs
+        return docs
+
 
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         query_str = {
             "match": {
-                "text": query
+                Field.CONTENT_KEY.value: query
             }
         }
 
         results = self._client.search(index=self._collection_name, query=query_str)
         docs = []
         for hit in results['hits']['hits']:
-            docs.append(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']))
+            docs.append(Document(
+                page_content=hit['_source'][Field.CONTENT_KEY.value],
+                vector=hit['_source'][Field.VECTOR.value],
+                metadata=hit['_source'][Field.METADATA_KEY.value],
+            ))
 
         return docs
 
     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
-        return self.add_texts(texts, embeddings, **kwargs)
+        metadatas = [d.metadata for d in texts]
+        self.create_collection(embeddings, metadatas)
+        self.add_texts(texts, embeddings, **kwargs)
+
+    def create_collection(
+        self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
+    ):
+        lock_name = f'vector_indexing_lock_{self._collection_name}'
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = f'vector_indexing_{self._collection_name}'
+            if redis_client.get(collection_exist_cache_key):
+                logger.info(f"Collection {self._collection_name} already exists.")
+                return
+
+            if not self._client.indices.exists(index=self._collection_name):
+                dim = len(embeddings[0])
+                mappings = {
+                    "properties": {
+                        Field.CONTENT_KEY.value: {"type": "text"},
+                        Field.VECTOR.value: {  # Make sure the dimension is correct here
+                            "type": "dense_vector",
+                            "dims": dim,
+                            "similarity": "cosine"
+                        },
+                        Field.METADATA_KEY.value: {
+                            "type": "object",
+                            "properties": {
+                                "doc_id": {"type": "keyword"}  # Map doc_id to keyword type
+                            }
+                        }
+                    }
+                }
+                self._client.indices.create(index=self._collection_name, mappings=mappings)
+
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
 
 
 class ElasticSearchVectorFactory(AbstractVectorFactory):
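
Reviewer note, not part of the patch: a minimal sketch of how the rewritten class
wires up, using the defaults from the new elasticsearch_config.py. The index name
and empty attribute list are hypothetical placeholders, and a live Elasticsearch
server must be reachable, since __init__ now calls _get_version()/_check_version().

    # Hypothetical wiring sketch (assumes a running Elasticsearch 8.x at the
    # default host/port and the default elastic/elastic credentials).
    from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
        ElasticSearchConfig,
        ElasticSearchVector,
    )

    config = ElasticSearchConfig(
        host='127.0.0.1',    # bare host: _init_client() prepends http://
        port=9200,           # now an int, matching ELASTICSEARCH_PORT (PositiveInt)
        username='elastic',
        password='elastic',
    )
    # 'example_index' and the empty attributes list are placeholders.
    vector = ElasticSearchVector(index_name='example_index', config=config, attributes=[])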
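For reference, the request that the rewritten search_by_vector() issues has roughly
this shape against a raw elasticsearch-py 8.x client (the `knn` keyword of
Elasticsearch.search is an 8.x API). The index name, credentials, vector values,
and the 768 dimension are made-up examples, and the literal field names assume
Field.VECTOR.value resolves to "vector" and Field.METADATA_KEY.value to "metadata".

    # Sketch of the kNN search performed by search_by_vector(); replaces the old
    # script_score/cosineSimilarity query with a native kNN request.
    from elasticsearch import Elasticsearch

    client = Elasticsearch(hosts='http://127.0.0.1:9200', basic_auth=('elastic', 'elastic'))
    knn = {
        "field": "vector",             # assumed value of Field.VECTOR.value
        "query_vector": [0.1] * 768,   # must match the indexed dense_vector dims
        "k": 10,                       # top_k default in the patch
    }
    resp = client.search(index='example_index', knn=knn, size=10)
    for hit in resp['hits']['hits']:
        print(hit['_score'], hit['_source']['metadata'])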
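And the index mapping that create_collection() installs, shown as a direct
indices.create call. The literal "page_content"/"vector"/"metadata" names stand in
for the Field enum values (an assumption; check api/core/rag/datasource/vdb/field.py),
and 'example_index' and dim=768 are placeholders.

    # Equivalent of the mapping created under the redis lock in create_collection():
    # a text field for full-text search, a cosine dense_vector for kNN, and a
    # keyword doc_id inside the metadata object for exact-match filtering/deletes.
    from elasticsearch import Elasticsearch

    client = Elasticsearch(hosts='http://127.0.0.1:9200', basic_auth=('elastic', 'elastic'))
    dim = 768  # must equal the embedding dimension of the documents being indexed
    client.indices.create(
        index='example_index',
        mappings={
            "properties": {
                "page_content": {"type": "text"},
                "vector": {"type": "dense_vector", "dims": dim, "similarity": "cosine"},
                "metadata": {
                    "type": "object",
                    "properties": {"doc_id": {"type": "keyword"}},
                },
            }
        },
    )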