Skip to content

Commit

Permalink
Add client aws_opensearch
Browse files Browse the repository at this point in the history
Signed-off-by: Cai Yudong <[email protected]>
  • Loading branch information
cydrain authored and XuanYang-cn committed Jul 18, 2024
1 parent 3fdb298 commit c45876c
Show file tree
Hide file tree
Showing 9 changed files with 419 additions and 12 deletions.
25 changes: 13 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,19 @@ pip install vectordb-bench[pinecone]
```
All the database client supported

|Optional database client|install command|
|---------------|---------------|
|pymilvus(*default*)|`pip install vectordb-bench`|
|all|`pip install vectordb-bench[all]`|
|qdrant|`pip install vectordb-bench[qdrant]`|
|pinecone|`pip install vectordb-bench[pinecone]`|
|weaviate|`pip install vectordb-bench[weaviate]`|
|elastic|`pip install vectordb-bench[elastic]`|
|pgvector|`pip install vectordb-bench[pgvector]`|
|pgvecto.rs|`pip install vectordb-bench[pgvecto_rs]`|
|redis|`pip install vectordb-bench[redis]`|
|chromadb|`pip install vectordb-bench[chromadb]`|
| Optional database client | install command |
|--------------------------|---------------------------------------------|
| pymilvus(*default*) | `pip install vectordb-bench` |
| all | `pip install vectordb-bench[all]` |
| qdrant | `pip install vectordb-bench[qdrant]` |
| pinecone | `pip install vectordb-bench[pinecone]` |
| weaviate | `pip install vectordb-bench[weaviate]` |
| elastic | `pip install vectordb-bench[elastic]` |
| pgvector | `pip install vectordb-bench[pgvector]` |
| pgvecto.rs | `pip install vectordb-bench[pgvecto_rs]` |
| redis | `pip install vectordb-bench[redis]` |
| chromadb | `pip install vectordb-bench[chromadb]` |
| awsopensearch | `pip install vectordb-bench[awsopensearch]` |

### Run

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ all = [
"psycopg2",
"psycopg",
"psycopg-binary",
"opensearch-dsl==2.1.0",
"opensearch-py==2.6.0",
]

qdrant = [ "qdrant-client" ]
Expand All @@ -72,6 +74,7 @@ pgvector = [ "psycopg", "psycopg-binary", "pgvector" ]
pgvecto_rs = [ "psycopg2" ]
redis = [ "redis" ]
chromadb = [ "chromadb" ]
awsopensearch = [ "awsopensearch" ]
zilliz_cloud = []

[project.urls]
Expand Down
13 changes: 13 additions & 0 deletions vectordb_bench/backend/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class DB(Enum):
PgVectoRS = "PgVectoRS"
Redis = "Redis"
Chroma = "Chroma"
AWSOpenSearch = "OpenSearch"
Test = "test"


Expand Down Expand Up @@ -78,6 +79,10 @@ def init_cls(self) -> Type[VectorDB]:
from .chroma.chroma import ChromaClient
return ChromaClient

if self == DB.AWSOpenSearch:
from .aws_opensearch.aws_opensearch import AWSOpenSearch
return AWSOpenSearch

@property
def config_cls(self) -> Type[DBConfig]:
"""Import while in use"""
Expand Down Expand Up @@ -121,6 +126,10 @@ def config_cls(self) -> Type[DBConfig]:
from .chroma.config import ChromaConfig
return ChromaConfig

if self == DB.AWSOpenSearch:
from .aws_opensearch.config import AWSOpenSearchConfig
return AWSOpenSearchConfig

def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
if self == DB.Milvus:
from .milvus.config import _milvus_case_config
Expand Down Expand Up @@ -150,6 +159,10 @@ def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseCon
from .pgvecto_rs.config import _pgvecto_rs_case_config
return _pgvecto_rs_case_config.get(index_type)

if self == DB.AWSOpenSearch:
from .aws_opensearch.config import AWSOpenSearchIndexConfig
return AWSOpenSearchIndexConfig

# DB.Pinecone, DB.Chroma, DB.Redis
return EmptyDBCaseConfig

Expand Down
159 changes: 159 additions & 0 deletions vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import logging
from contextlib import contextmanager
import time
from typing import Iterable, Type
from ..api import VectorDB, DBCaseConfig, DBConfig, IndexType
from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig
from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk

log = logging.getLogger(__name__)


class AWSOpenSearch(VectorDB):
def __init__(
self,
dim: int,
db_config: dict,
db_case_config: AWSOpenSearchIndexConfig,
index_name: str = "vdb_bench_index", # must be lowercase
id_col_name: str = "id",
vector_col_name: str = "embedding",
drop_old: bool = False,
**kwargs,
):
self.dim = dim
self.db_config = db_config
self.case_config = db_case_config
self.index_name = index_name
self.id_col_name = id_col_name
self.category_col_names = [
f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000]
]
self.vector_col_name = vector_col_name

log.info(f"AWS_OpenSearch client config: {self.db_config}")
client = OpenSearch(**self.db_config)
if drop_old:
log.info(f"AWS_OpenSearch client drop old index: {self.index_name}")
is_existed = client.indices.exists(index=self.index_name)
if is_existed:
client.indices.delete(index=self.index_name)
self._create_index(client)

@classmethod
def config_cls(cls) -> AWSOpenSearchConfig:
return AWSOpenSearchConfig

@classmethod
def case_config_cls(
cls, index_type: IndexType | None = None
) -> AWSOpenSearchIndexConfig:
return AWSOpenSearchIndexConfig

def _create_index(self, client: OpenSearch):
settings = {
"index": {
"knn": True,
# "number_of_shards": 5,
# "refresh_interval": "600s",
}
}
mappings = {
"properties": {
self.id_col_name: {"type": "integer"},
**{
categoryCol: {"type": "keyword"}
for categoryCol in self.category_col_names
},
self.vector_col_name: {
"type": "knn_vector",
"dimension": self.dim,
"method": self.case_config.index_param(),
},
}
}
try:
client.indices.create(
index=self.index_name, body=dict(settings=settings, mappings=mappings)
)
except Exception as e:
log.warning(f"Failed to create index: {self.index_name} error: {str(e)}")
raise e from None

@contextmanager
def init(self) -> None:
"""connect to elasticsearch"""
self.client = OpenSearch(**self.db_config)

yield
# self.client.transport.close()
self.client = None
del self.client

def insert_embeddings(
self,
embeddings: Iterable[list[float]],
metadata: list[int],
**kwargs,
) -> tuple[int, Exception]:
"""Insert the embeddings to the elasticsearch."""
assert self.client is not None, "should self.init() first"

insert_data = []
for i in range(len(embeddings)):
insert_data.append({"index": {"_index": self.index_name, "_id": metadata[i]}})
insert_data.append({self.vector_col_name: embeddings[i]})
try:
resp = self.client.bulk(insert_data)
log.info(f"AWS_OpenSearch adding documents: {len(resp['items'])}")
resp = self.client.indices.stats(self.index_name)
log.info(f"Total document count in index: {resp['_all']['primaries']['indexing']['index_total']}")
return (len(embeddings), None)
except Exception as e:
log.warning(f"Failed to insert data: {self.index_name} error: {str(e)}")
time.sleep(10)
return self.insert_embeddings(embeddings, metadata)

def search_embedding(
self,
query: list[float],
k: int = 100,
filters: dict | None = None,
) -> list[int]:
"""Get k most similar embeddings to query vector.
Args:
query(list[float]): query embedding to look up documents similar to.
k(int): Number of most similar embeddings to return. Defaults to 100.
filters(dict, optional): filtering expression to filter the data while searching.
Returns:
list[tuple[int, float]]: list of k most similar embeddings in (id, score) tuple to the query embedding.
"""
assert self.client is not None, "should self.init() first"

body = {
"size": k,
"query": {"knn": {self.vector_col_name: {"vector": query, "k": k}}},
}
try:
resp = self.client.search(index=self.index_name, body=body)
log.info(f'Search took: {resp["took"]}')
log.info(f'Search shards: {resp["_shards"]}')
log.info(f'Search hits total: {resp["hits"]["total"]}')
result = [int(d["_id"]) for d in resp["hits"]["hits"]]
# log.info(f'success! length={len(res)}')

return result
except Exception as e:
log.warning(f"Failed to search: {self.index_name} error: {str(e)}")
raise e from None

def optimize(self):
"""optimize will be called between insertion and search in performance cases."""
pass

def ready_to_load(self):
"""ready_to_load will be called before load in load cases."""
pass
44 changes: 44 additions & 0 deletions vectordb_bench/backend/clients/aws_opensearch/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Annotated, TypedDict, Unpack

import click
from pydantic import SecretStr

from ....cli.cli import (
CommonTypedDict,
HNSWFlavor2,
cli,
click_parameter_decorators_from_typed_dict,
run,
)
from .. import DB


class AWSOpenSearchTypedDict(TypedDict):
host: Annotated[
str, click.option("--host", type=str, help="Db host", required=True)
]
port: Annotated[int, click.option("--port", type=int, default=443, help="Db Port")]
user: Annotated[str, click.option("--user", type=str, default="admin", help="Db User")]
password: Annotated[str, click.option("--password", type=str, help="Db password")]


class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2):
...


@cli.command()
@click_parameter_decorators_from_typed_dict(AWSOpenSearchHNSWTypedDict)
def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig
run(
db=DB.AWSOpenSearch,
db_config=AWSOpenSearchConfig(
host=parameters["host"],
port=parameters["port"],
user=parameters["user"],
password=SecretStr(parameters["password"]),
),
db_case_config=AWSOpenSearchIndexConfig(
),
**parameters,
)
58 changes: 58 additions & 0 deletions vectordb_bench/backend/clients/aws_opensearch/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from enum import Enum
from pydantic import SecretStr, BaseModel

from ..api import DBConfig, DBCaseConfig, MetricType, IndexType


class AWSOpenSearchConfig(DBConfig, BaseModel):
host: str = ""
port: int = 443
user: str = ""
password: SecretStr = ""

def to_dict(self) -> dict:
return {
"hosts": [{'host': self.host, 'port': self.port}],
"http_auth": (self.user, self.password.get_secret_value()),
"use_ssl": True,
"http_compress": True,
"verify_certs": True,
"ssl_assert_hostname": False,
"ssl_show_warn": False,
"timeout": 600,
}


class AWSOS_Engine(Enum):
nmslib = "nmslib"
faiss = "faiss"
lucene = "Lucene"


class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
metric_type: MetricType = MetricType.L2
engine: AWSOS_Engine = AWSOS_Engine.nmslib
efConstruction: int = 360
M: int = 30

def parse_metric(self) -> str:
if self.metric_type == MetricType.IP:
return "innerproduct" # only support faiss / nmslib, not for Lucene.
elif self.metric_type == MetricType.COSINE:
return "cosinesimil"
return "l2"

def index_param(self) -> dict:
params = {
"name": "hnsw",
"space_type": self.parse_metric(),
"engine": self.engine.value,
"parameters": {
"ef_construction": self.efConstruction,
"m": self.M
}
}
return params

def search_param(self) -> dict:
return {}
Loading

0 comments on commit c45876c

Please sign in to comment.