Skip to content

Commit

Permalink
feat: rewrite Elasticsearch index and search code to achieve Elastics…
Browse files Browse the repository at this point in the history
…earch vector and full-text search (#7641)

Co-authored-by: haokai <[email protected]>
Co-authored-by: crazywoola <[email protected]>
Co-authored-by: Bowen Liang <[email protected]>
Co-authored-by: wellCh4n <[email protected]>
  • Loading branch information
5 people authored Aug 27, 2024
1 parent e7afee1 commit 122ce41
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 52 deletions.
2 changes: 2 additions & 0 deletions api/configs/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
from configs.middleware.vdb.milvus_config import MilvusConfig
from configs.middleware.vdb.myscale_config import MyScaleConfig
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
Expand Down Expand Up @@ -200,5 +201,6 @@ class MiddlewareConfig(
TencentVectorDBConfig,
TiDBVectorConfig,
WeaviateConfig,
ElasticsearchConfig,
):
pass
30 changes: 30 additions & 0 deletions api/configs/middleware/vdb/elasticsearch_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Optional

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings


class ElasticsearchConfig(BaseSettings):
    """
    Connection settings for the Elasticsearch vector store
    """

    # Host of the Elasticsearch node; defaults to the loopback address.
    ELASTICSEARCH_HOST: Optional[str] = Field(
        default="127.0.0.1",
        description="Elasticsearch host",
    )

    # Port of the Elasticsearch node; must be a positive integer.
    ELASTICSEARCH_PORT: PositiveInt = Field(
        default=9200,
        description="Elasticsearch port",
    )

    # User name for authenticating against Elasticsearch.
    ELASTICSEARCH_USERNAME: Optional[str] = Field(
        default="elastic",
        description="Elasticsearch username",
    )

    # Password for authenticating against Elasticsearch.
    ELASTICSEARCH_PASSWORD: Optional[str] = Field(
        default="elastic",
        description="Elasticsearch password",
    )
131 changes: 79 additions & 52 deletions api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
import json
from typing import Any
import logging
from typing import Any, Optional
from urllib.parse import urlparse

import requests
from elasticsearch import Elasticsearch
from flask import current_app
from pydantic import BaseModel, model_validator

from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

logger = logging.getLogger(__name__)


class ElasticSearchConfig(BaseModel):
host: str
port: str
port: int
username: str
password: str

Expand All @@ -37,12 +43,19 @@ class ElasticSearchVector(BaseVector):
def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list):
    """Connect to Elasticsearch and validate the server version.

    Args:
        index_name: name of the index; it is lower-cased before being
            handed to the base class.
        config: connection settings used to build the client.
        attributes: stored on the instance as ``_attributes``.
    """
    lowered_name = index_name.lower()
    super().__init__(lowered_name)
    self._client = self._init_client(config)
    # The version is read from the freshly built client and checked right
    # away, so construction fails for an unsupported server.
    self._version = self._get_version()
    self._check_version()
    self._attributes = attributes

def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
try:
parsed_url = urlparse(config.host)
if parsed_url.scheme in ['http', 'https']:
hosts = f'{config.host}:{config.port}'
else:
hosts = f'http://{config.host}:{config.port}'
client = Elasticsearch(
hosts=f'{config.host}:{config.port}',
hosts=hosts,
basic_auth=(config.username, config.password),
request_timeout=100000,
retry_on_timeout=True,
Expand All @@ -53,42 +66,27 @@ def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:

return client

def _get_version(self) -> str:
    """Return the version number reported by the client's ``info()`` call."""
    server_info = self._client.info()
    version_section = server_info['version']
    return version_section['number']

def _check_version(self):
    """Reject Elasticsearch servers older than major version 8.

    Version 8.0.0 itself is accepted.

    Raises:
        ValueError: if the major component of ``self._version`` is below 8.
    """
    version = self._version
    # Compare the major component numerically. A plain string comparison is
    # lexicographic, so "10.0.0" would sort before "8.0.0" and a newer server
    # would be rejected.
    major = version.split('.', 1)[0]
    if major.isdecimal():
        too_old = int(major) < 8
    else:
        # Unrecognised version format: fall back to the plain string
        # comparison rather than guessing.
        too_old = version < '8.0.0'
    if too_old:
        raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")

def get_type(self) -> str:
    """Return the identifier string of this vector store backend."""
    backend_name = 'elasticsearch'
    return backend_name

def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
    """Index every document together with its embedding.

    Index creation is not done here; ``create_collection`` handles it.

    Args:
        documents: documents to store; ``page_content`` and ``metadata`` of
            each one are written to the index.
        embeddings: one vector per document, matched by position.
        **kwargs: not used by this method.

    Returns:
        The ids produced by ``self._get_uuids(documents)``, in document order.
    """
    uuids = self._get_uuids(documents)
    for i, document in enumerate(documents):
        vector = embeddings[i]
        self._client.index(
            index=self._collection_name,
            id=uuids[i],
            document={
                Field.CONTENT_KEY.value: document.page_content,
                # An empty embedding is stored as null, empty metadata as {}.
                Field.VECTOR.value: vector if vector else None,
                Field.METADATA_KEY.value: document.metadata if document.metadata else {},
            },
        )

    # Refresh once after all writes instead of per document.
    self._client.indices.refresh(index=self._collection_name)
    return uuids

Expand Down Expand Up @@ -116,28 +114,21 @@ def delete(self) -> None:
self._client.indices.delete(index=self._collection_name)

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
query_str = {
"query": {
"script_score": {
"query": {
"match_all": {}
},
"script": {
"source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
"params": {
"query_vector": query_vector
}
}
}
}
top_k = kwargs.get("top_k", 10)
knn = {
"field": Field.VECTOR.value,
"query_vector": query_vector,
"k": top_k
}

results = self._client.search(index=self._collection_name, body=query_str)
results = self._client.search(index=self._collection_name, knn=knn, size=top_k)

docs_and_scores = []
for hit in results['hits']['hits']:
docs_and_scores.append(
(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']), hit['_score']))
(Document(page_content=hit['_source'][Field.CONTENT_KEY.value],
vector=hit['_source'][Field.VECTOR.value],
metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score']))

docs = []
for doc, score in docs_and_scores:
Expand All @@ -146,25 +137,61 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
doc.metadata['score'] = score
docs.append(doc)

# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)

return docs

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
    """Run a ``match`` query against the content field.

    Args:
        query: text to match against the stored document content.
        **kwargs: not used by this method; no result size is passed to the
            search call.

    Returns:
        One ``Document`` per hit, carrying the stored content, vector and
        metadata, in the order the hits were returned.
    """
    query_str = {"match": {Field.CONTENT_KEY.value: query}}
    results = self._client.search(index=self._collection_name, query=query_str)
    return [
        Document(
            page_content=hit['_source'][Field.CONTENT_KEY.value],
            vector=hit['_source'][Field.VECTOR.value],
            metadata=hit['_source'][Field.METADATA_KEY.value],
        )
        for hit in results['hits']['hits']
    ]

def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
    """Make sure the index exists, then index the given documents.

    Args:
        texts: documents to store.
        embeddings: one vector per document; also passed to
            ``create_collection``.
        **kwargs: forwarded to ``add_texts``.

    Returns:
        Whatever ``add_texts`` returns (the document ids).
    """
    metadatas = [d.metadata for d in texts]
    # The index must exist before the first document is written.
    self.create_collection(embeddings, metadatas)
    return self.add_texts(texts, embeddings, **kwargs)

def create_collection(
    self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
    """Create the Elasticsearch index for this collection if it is missing.

    The work runs under a Redis lock named after the collection. A Redis
    marker with a one-hour expiry lets later calls return without
    contacting Elasticsearch.

    Args:
        embeddings: vectors about to be indexed; the length of the first
            one becomes the ``dims`` of the vector field.
        metadatas: currently unused.
        index_params: currently unused.

    Raises:
        ValueError: if the index has to be created but ``embeddings`` is
            empty, so the vector dimension cannot be determined.
    """
    lock_name = f'vector_indexing_lock_{self._collection_name}'
    with redis_client.lock(lock_name, timeout=20):
        collection_exist_cache_key = f'vector_indexing_{self._collection_name}'
        if redis_client.get(collection_exist_cache_key):
            logger.info("Collection %s already exists.", self._collection_name)
            return

        if not self._client.indices.exists(index=self._collection_name):
            if not embeddings:
                raise ValueError("Cannot create the Elasticsearch index without at least one embedding")
            dim = len(embeddings[0])
            mappings = {
                "properties": {
                    Field.CONTENT_KEY.value: {"type": "text"},
                    Field.VECTOR.value: {  # Make sure the dimension is correct here
                        "type": "dense_vector",
                        "dims": dim,
                        # Request indexing explicitly: the kNN query used for
                        # vector search needs an indexed field, and older 8.x
                        # releases default ``index`` to false.
                        "index": True,
                        "similarity": "cosine"
                    },
                    Field.METADATA_KEY.value: {
                        "type": "object",
                        "properties": {
                            "doc_id": {"type": "keyword"}  # Map doc_id to keyword type
                        }
                    }
                }
            }
            self._client.indices.create(index=self._collection_name, mappings=mappings)

        # Remember for one hour that the index exists.
        redis_client.set(collection_exist_cache_key, 1, ex=3600)


class ElasticSearchVectorFactory(AbstractVectorFactory):
Expand Down

0 comments on commit 122ce41

Please sign in to comment.