feat: rewrite Elasticsearch index and search code to support Elasticsearch vector and full-text search #7641

Merged · 5 commits · Aug 27, 2024

2 changes: 2 additions & 0 deletions api/configs/middleware/__init__.py
@@ -13,6 +13,7 @@
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
from configs.middleware.vdb.milvus_config import MilvusConfig
from configs.middleware.vdb.myscale_config import MyScaleConfig
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
@@ -200,5 +201,6 @@ class MiddlewareConfig(
TencentVectorDBConfig,
TiDBVectorConfig,
WeaviateConfig,
ElasticsearchConfig,
):
pass
30 changes: 30 additions & 0 deletions api/configs/middleware/vdb/elasticsearch_config.py
@@ -0,0 +1,30 @@
from typing import Optional

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings


class ElasticsearchConfig(BaseSettings):
"""
Elasticsearch configs
"""

ELASTICSEARCH_HOST: Optional[str] = Field(
description="Elasticsearch host",
default="127.0.0.1",
)

ELASTICSEARCH_PORT: PositiveInt = Field(
description="Elasticsearch port",
default=9200,
)

ELASTICSEARCH_USERNAME: Optional[str] = Field(
description="Elasticsearch username",
default="elastic",
)

ELASTICSEARCH_PASSWORD: Optional[str] = Field(
description="Elasticsearch password",
default="elastic",
)
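
For context, pydantic-settings resolves each of these fields from a same-named environment variable, falling back to the declared default. A minimal sketch of that behaviour (the host value is hypothetical):

import os

from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig

# Hypothetical deployment values; unset variables fall back to the defaults above.
os.environ["ELASTICSEARCH_HOST"] = "es.internal.example"
os.environ["ELASTICSEARCH_PORT"] = "9200"

config = ElasticsearchConfig()
print(config.ELASTICSEARCH_HOST)  # 'es.internal.example'
print(config.ELASTICSEARCH_PORT)  # 9200, parsed and validated as a PositiveInt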
131 changes: 79 additions & 52 deletions api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
@@ -1,22 +1,28 @@
import json
from typing import Any
import logging
from typing import Any, Optional
from urllib.parse import urlparse

import requests
from elasticsearch import Elasticsearch
from flask import current_app
from pydantic import BaseModel, model_validator

from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

logger = logging.getLogger(__name__)


class ElasticSearchConfig(BaseModel):
host: str
port: str
port: int
username: str
password: str

@@ -37,12 +43,19 @@ class ElasticSearchVector(BaseVector):
def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list):
super().__init__(index_name.lower())
self._client = self._init_client(config)
self._version = self._get_version()
self._check_version()
self._attributes = attributes

def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
try:
parsed_url = urlparse(config.host)
if parsed_url.scheme in ['http', 'https']:
hosts = f'{config.host}:{config.port}'
else:
hosts = f'http://{config.host}:{config.port}'
client = Elasticsearch(
hosts=f'{config.host}:{config.port}',
hosts=hosts,
basic_auth=(config.username, config.password),
request_timeout=100000,
retry_on_timeout=True,
@@ -53,42 +66,27 @@ def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:

return client
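
The scheme check above exists because urlparse only reports a scheme when the value actually contains one; a bare hostname parses with an empty scheme, so the code prefixes 'http://' to keep the hosts string valid. A quick illustration:

from urllib.parse import urlparse

print(urlparse('https://es.example.com').scheme)  # 'https' -> host is used as-is
print(urlparse('es.example.com').scheme)          # ''      -> 'http://' gets prefixed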

def _get_version(self) -> str:
info = self._client.info()
return info['version']['number']

def _check_version(self):
# Compare the major version numerically; a plain string comparison would mis-order versions such as '10.x' vs '8.x'.
if int(self._version.split('.')[0]) < 8:
raise ValueError("Elasticsearch vector database version must be 8.0.0 or higher")

def get_type(self) -> str:
return 'elasticsearch'

def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]

if not self._client.indices.exists(index=self._collection_name):
dim = len(embeddings[0])
mapping = {
"properties": {
"text": {
"type": "text"
},
"vector": {
"type": "dense_vector",
"index": True,
"dims": dim,
"similarity": "l2_norm"
},
}
}
self._client.indices.create(index=self._collection_name, mappings=mapping)

added_ids = []
for i, text in enumerate(texts):
for i in range(len(documents)):
self._client.index(index=self._collection_name,
id=uuids[i],
document={
"text": text,
"vector": embeddings[i] if embeddings[i] else None,
"metadata": metadatas[i] if metadatas[i] else {},
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {}
})
added_ids.append(uuids[i])

self._client.indices.refresh(index=self._collection_name)
return uuids
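
Indexing one document per request pays a network round trip each time. The official client also ships a bulk helper that batches the same operations; the following is a sketch under the same field layout, not part of this PR:

from elasticsearch import helpers

from core.rag.datasource.vdb.field import Field

def add_texts_bulk(client, index_name, documents, embeddings, uuids):
    # Build one bulk action per document; field names mirror the mapping used above.
    actions = [
        {
            "_index": index_name,
            "_id": uuids[i],
            "_source": {
                Field.CONTENT_KEY.value: doc.page_content,
                Field.VECTOR.value: embeddings[i],
                Field.METADATA_KEY.value: doc.metadata or {},
            },
        }
        for i, doc in enumerate(documents)
    ]
    helpers.bulk(client, actions)
    client.indices.refresh(index=index_name)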

@@ -116,28 +114,21 @@ def delete(self) -> None:
self._client.indices.delete(index=self._collection_name)

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
query_str = {
"query": {
"script_score": {
"query": {
"match_all": {}
},
"script": {
"source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
"params": {
"query_vector": query_vector
}
}
}
}
top_k = kwargs.get("top_k", 10)
knn = {
"field": Field.VECTOR.value,
"query_vector": query_vector,
"k": top_k
}

results = self._client.search(index=self._collection_name, body=query_str)
results = self._client.search(index=self._collection_name, knn=knn, size=top_k)

docs_and_scores = []
for hit in results['hits']['hits']:
docs_and_scores.append(
(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']), hit['_score']))
(Document(page_content=hit['_source'][Field.CONTENT_KEY.value],
vector=hit['_source'][Field.VECTOR.value],
metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score']))

docs = []
for doc, score in docs_and_scores:
@@ -146,25 +137,61 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
doc.metadata['score'] = score
docs.append(doc)

# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)

return docs
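
A hypothetical caller embeds the query with the same model used at indexing time and passes the vector in; every name below is illustrative rather than part of this PR:

# Illustrative usage; 'embedding_model' and 'store' are hypothetical objects.
query_vector = embedding_model.embed_query("how do I reset my password?")
docs = store.search_by_vector(query_vector, top_k=5)
for doc in docs:
    print(doc.metadata['score'], doc.page_content[:80])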

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_str = {
"match": {
"text": query
Field.CONTENT_KEY.value: query
}
}
results = self._client.search(index=self._collection_name, query=query_str)
docs = []
for hit in results['hits']['hits']:
docs.append(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']))
docs.append(Document(
page_content=hit['_source'][Field.CONTENT_KEY.value],
vector=hit['_source'][Field.VECTOR.value],
metadata=hit['_source'][Field.METADATA_KEY.value],
))

return docs
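
With both methods in place, a caller can combine the vector and full-text legs named in the PR title; a hypothetical merge that de-duplicates by doc_id:

# Hypothetical hybrid retrieval; 'store', 'query_vector' and 'query_text' are illustrative.
seen, merged = set(), []
for doc in store.search_by_vector(query_vector, top_k=10) + store.search_by_full_text(query_text):
    doc_id = doc.metadata.get('doc_id')
    if doc_id not in seen:
        seen.add(doc_id)
        merged.append(doc)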

def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
return self.add_texts(texts, embeddings, **kwargs)
metadatas = [d.metadata for d in texts]
self.create_collection(embeddings, metadatas)
self.add_texts(texts, embeddings, **kwargs)

def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
lock_name = f'vector_indexing_lock_{self._collection_name}'
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f'vector_indexing_{self._collection_name}'
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name} already exists.")
return

if not self._client.indices.exists(index=self._collection_name):
dim = len(embeddings[0])
mappings = {
"properties": {
Field.CONTENT_KEY.value: {"type": "text"},
Field.VECTOR.value: { # Make sure the dimension is correct here
"type": "dense_vector",
"dims": dim,
"similarity": "cosine"
},
Field.METADATA_KEY.value: {
"type": "object",
"properties": {
"doc_id": {"type": "keyword"} # Map doc_id to keyword type
}
}
}
}
self._client.indices.create(index=self._collection_name, mappings=mappings)

redis_client.set(collection_exist_cache_key, 1, ex=3600)
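
The lock-plus-cache pair is a double-checked guard: the Redis lock serializes workers racing to create the same index, and the cache key lets later callers skip the index-exists check for an hour. The same pattern condensed, with hypothetical names:

from extensions.ext_redis import redis_client

def ensure_created(name: str, create_fn) -> None:
    # Hypothetical helper condensing the guard used in create_collection above.
    with redis_client.lock(f'vector_indexing_lock_{name}', timeout=20):
        cache_key = f'vector_indexing_{name}'
        if redis_client.get(cache_key):
            return  # another worker already created the index
        create_fn()
        redis_client.set(cache_key, 1, ex=3600)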


class ElasticSearchVectorFactory(AbstractVectorFactory):