Skip to content

Commit

Permalink
Added OceanBase as an option for the vector store in Dify (#10010)
Browse files Browse the repository at this point in the history
  • Loading branch information
powerfooI authored Oct 29, 2024
1 parent 5580bcf commit 878d13e
Show file tree
Hide file tree
Showing 17 changed files with 427 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ docker/volumes/unstructured/*
docker/volumes/pgvector/data/*
docker/volumes/pgvecto_rs/data/*
docker/volumes/couchbase/*
docker/volumes/oceanbase/*

docker/nginx/conf.d/default.conf
docker/nginx/ssl/*
Expand Down
8 changes: 8 additions & 0 deletions api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,14 @@ VIKINGDB_SCHEMA=http
VIKINGDB_CONNECTION_TIMEOUT=30
VIKINGDB_SOCKET_TIMEOUT=30

# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
OCEANBASE_VECTOR_PORT=2881
OCEANBASE_VECTOR_USER=root@test
OCEANBASE_VECTOR_PASSWORD=
OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G

# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
Expand Down
1 change: 1 addition & 0 deletions api/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ def migrate_knowledge_vector_database():
VectorType.VIKINGDB,
VectorType.UPSTASH,
VectorType.COUCHBASE,
VectorType.OCEANBASE,
}
page = 1
while True:
Expand Down
2 changes: 2 additions & 0 deletions api/configs/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
from configs.middleware.vdb.milvus_config import MilvusConfig
from configs.middleware.vdb.myscale_config import MyScaleConfig
from configs.middleware.vdb.oceanbase_config import OceanBaseVectorConfig
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
from configs.middleware.vdb.oracle_config import OracleConfig
from configs.middleware.vdb.pgvector_config import PGVectorConfig
Expand Down Expand Up @@ -257,5 +258,6 @@ class MiddlewareConfig(
VikingDBConfig,
UpstashConfig,
TidbOnQdrantConfig,
OceanBaseVectorConfig,
):
pass
35 changes: 35 additions & 0 deletions api/configs/middleware/vdb/oceanbase_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Optional

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings


class OceanBaseVectorConfig(BaseSettings):
"""
Configuration settings for OceanBase Vector database
"""

OCEANBASE_VECTOR_HOST: Optional[str] = Field(
description="Hostname or IP address of the OceanBase Vector server (e.g. 'localhost')",
default=None,
)

OCEANBASE_VECTOR_PORT: Optional[PositiveInt] = Field(
description="Port number on which the OceanBase Vector server is listening (default is 2881)",
default=2881,
)

OCEANBASE_VECTOR_USER: Optional[str] = Field(
description="Username for authenticating with the OceanBase Vector database",
default=None,
)

OCEANBASE_VECTOR_PASSWORD: Optional[str] = Field(
description="Password for authenticating with the OceanBase Vector database",
default=None,
)

OCEANBASE_VECTOR_DATABASE: Optional[str] = Field(
description="Name of the OceanBase Vector database to connect to",
default=None,
)
2 changes: 2 additions & 0 deletions api/controllers/console/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,7 @@ def get(self):
| VectorType.BAIDU
| VectorType.VIKINGDB
| VectorType.UPSTASH
| VectorType.OCEANBASE
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
Expand Down Expand Up @@ -669,6 +670,7 @@ def get(self, vector_type):
| VectorType.BAIDU
| VectorType.VIKINGDB
| VectorType.UPSTASH
| VectorType.OCEANBASE
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
Expand Down
Empty file.
209 changes: 209 additions & 0 deletions api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import json
import logging
import math
from typing import Any

from pydantic import BaseModel, model_validator
from pyobvector import VECTOR, ObVecClient
from sqlalchemy import JSON, Column, String, func
from sqlalchemy.dialects.mysql import LONGTEXT

from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

logger = logging.getLogger(__name__)

DEFAULT_OCEANBASE_HNSW_BUILD_PARAM = {"M": 16, "efConstruction": 256}
DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM = {"efSearch": 64}
OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE = "HNSW"
DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE = "l2"


class OceanBaseVectorConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str

@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config OCEANBASE_VECTOR_HOST is required")
if not values["port"]:
raise ValueError("config OCEANBASE_VECTOR_PORT is required")
if not values["user"]:
raise ValueError("config OCEANBASE_VECTOR_USER is required")
if not values["database"]:
raise ValueError("config OCEANBASE_VECTOR_DATABASE is required")
return values


class OceanBaseVector(BaseVector):
def __init__(self, collection_name: str, config: OceanBaseVectorConfig):
super().__init__(collection_name)
self._config = config
self._hnsw_ef_search = -1
self._client = ObVecClient(
uri=f"{self._config.host}:{self._config.port}",
user=self._config.user,
password=self._config.password,
db_name=self._config.database,
)

def get_type(self) -> str:
return VectorType.OCEANBASE

def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self._vec_dim = len(embeddings[0])
self._create_collection()
self.add_texts(texts, embeddings)

def _create_collection(self) -> None:
lock_name = "vector_indexing_lock_" + self._collection_name
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_" + self._collection_name
if redis_client.get(collection_exist_cache_key):
return

if self._client.check_table_exists(self._collection_name):
return

self.delete()

cols = [
Column("id", String(36), primary_key=True, autoincrement=False),
Column("vector", VECTOR(self._vec_dim)),
Column("text", LONGTEXT),
Column("metadata", JSON),
]
vidx_params = self._client.prepare_index_params()
vidx_params.add_index(
field_name="vector",
index_type=OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE,
index_name="vector_index",
metric_type=DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE,
params=DEFAULT_OCEANBASE_HNSW_BUILD_PARAM,
)

self._client.create_table_with_index_params(
table_name=self._collection_name,
columns=cols,
vidxs=vidx_params,
)
vals = []
params = self._client.perform_raw_text_sql("SHOW PARAMETERS LIKE '%ob_vector_memory_limit_percentage%'")
for row in params:
val = int(row[6])
vals.append(val)
if len(vals) == 0:
print("ob_vector_memory_limit_percentage not found in parameters.")
exit(1)
if any(val == 0 for val in vals):
try:
self._client.perform_raw_text_sql("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30")
except Exception as e:
raise Exception(
"Failed to set ob_vector_memory_limit_percentage. "
+ "Maybe the database user has insufficient privilege.",
e,
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)

def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
ids = self._get_uuids(documents)
for id, doc, emb in zip(ids, documents, embeddings):
self._client.insert(
table_name=self._collection_name,
data={
"id": id,
"vector": emb,
"text": doc.page_content,
"metadata": doc.metadata,
},
)

def text_exists(self, id: str) -> bool:
cur = self._client.get(table_name=self._collection_name, id=id)
return cur.rowcount != 0

def delete_by_ids(self, ids: list[str]) -> None:
self._client.delete(table_name=self._collection_name, ids=ids)

def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
cur = self._client.get(
table_name=self._collection_name,
where_clause=f"metadata->>'$.{key}' = '{value}'",
output_column_name=["id"],
)
return [row[0] for row in cur]

def delete_by_metadata_field(self, key: str, value: str) -> None:
ids = self.get_ids_by_metadata_field(key, value)
self.delete_by_ids(ids)

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
if ef_search != self._hnsw_ef_search:
self._client.set_ob_hnsw_ef_search(ef_search)
self._hnsw_ef_search = ef_search
topk = kwargs.get("top_k", 10)
cur = self._client.ann_search(
table_name=self._collection_name,
vec_column_name="vector",
vec_data=query_vector,
topk=topk,
distance_func=func.l2_distance,
output_column_names=["text", "metadata"],
with_dist=True,
)
docs = []
for text, metadata, distance in cur:
metadata = json.loads(metadata)
metadata["score"] = 1 - distance / math.sqrt(2)
docs.append(
Document(
page_content=text,
metadata=metadata,
)
)
return docs

def delete(self) -> None:
self._client.drop_table_if_exist(self._collection_name)


class OceanBaseVectorFactory(AbstractVectorFactory):
def init_vector(
self,
dataset: Dataset,
attributes: list,
embeddings: Embeddings,
) -> BaseVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OCEANBASE, collection_name))
return OceanBaseVector(
collection_name,
OceanBaseVectorConfig(
host=dify_config.OCEANBASE_VECTOR_HOST,
port=dify_config.OCEANBASE_VECTOR_PORT,
user=dify_config.OCEANBASE_VECTOR_USER,
password=(dify_config.OCEANBASE_VECTOR_PASSWORD or ""),
database=dify_config.OCEANBASE_VECTOR_DATABASE,
),
)
4 changes: 4 additions & 0 deletions api/core/rag/datasource/vdb/vector_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory

return TidbOnQdrantVectorFactory
case VectorType.OCEANBASE:
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory

return OceanBaseVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")

Expand Down
1 change: 1 addition & 0 deletions api/core/rag/datasource/vdb/vector_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ class VectorType(str, Enum):
VIKINGDB = "vikingdb"
UPSTASH = "upstash"
TIDB_ON_QDRANT = "tidb_on_qdrant"
OCEANBASE = "oceanbase"
Loading

0 comments on commit 878d13e

Please sign in to comment.