Skip to content

Commit

Permalink
Avoiding parameter conflicts and adding compatibility logic for prepa…
Browse files Browse the repository at this point in the history
…re_index_params (#1954)

Signed-off-by: zhenshan.cao <[email protected]>
  • Loading branch information
czs007 authored Mar 8, 2024
1 parent 5b5e361 commit f4c3466
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 26 deletions.
70 changes: 70 additions & 0 deletions examples/milvus_client/customize_schema_auto_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import time
import numpy as np
from pymilvus import (
MilvusClient,
DataType
)

fmt = "\n=== {:30} ===\n"
dim = 8
collection_name = "hello_milvus"
milvus_client = MilvusClient("http://localhost:19530")

has_collection = milvus_client.has_collection(collection_name, timeout=5)
if has_collection:
milvus_client.drop_collection(collection_name)

schema = milvus_client.create_schema(enable_dynamic_field=True, auto_id=True)
schema.add_field("id", DataType.INT64, is_primary=True)
schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim)
schema.add_field("title", DataType.VARCHAR, max_length=64)


index_params = milvus_client.prepare_index_params()
index_params.add_index(field_name = "embeddings", metric_type="L2")
milvus_client.create_collection(collection_name, schema=schema, index_params=index_params, consistency_level="Strong")

print(fmt.format(" all collections "))
print(milvus_client.list_collections())

print(fmt.format(f"schema of collection {collection_name}"))
print(milvus_client.describe_collection(collection_name))

rng = np.random.default_rng(seed=19530)
rows = [
{"embeddings": rng.random((1, dim))[0], "a": 100, "title": "t1"},
{"embeddings": rng.random((1, dim))[0], "b": 200, "title": "t2"},
{"embeddings": rng.random((1, dim))[0], "c": 300, "title": "t3"},
{"embeddings": rng.random((1, dim))[0], "d": 400, "title": "t4"},
{"embeddings": rng.random((1, dim))[0], "e": 500, "title": "t5"},
{"embeddings": rng.random((1, dim))[0], "f": 600, "title": "t6"},
]

print(fmt.format("Start inserting entities"))
insert_result = milvus_client.insert(collection_name, rows)
print(fmt.format("Inserting entities done"))
print(insert_result)


print(fmt.format("Start load collection "))
milvus_client.load_collection(collection_name)

print(fmt.format("Start query by specifying primary keys"))
query_results = milvus_client.query(collection_name, ids=insert_result['ids'][0])
print(query_results[0])

print(fmt.format("Start query by specifying filtering expression"))
query_results = milvus_client.query(collection_name, filter= "f == 600 or title == 't2'")
for ret in query_results:
print(ret)

rng = np.random.default_rng(seed=19530)
vectors_to_search = rng.random((1, dim))

print(fmt.format(f"Start search with retrieve serveral fields."))
result = milvus_client.search(collection_name, vectors_to_search, limit=3, output_fields=["pk", "a", "b"])
for hits in result:
for hit in hits:
print(f"hit: {hit}")

milvus_client.drop_collection(collection_name)
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dimension)
schema.add_field("info", DataType.JSON)

index_param = client.prepare_index_params("embeddings", metric_type="L2")
client.create_collection_with_schema(collection_name, schema, index_param)
index_params = client.prepare_index_params("embeddings", metric_type="L2")
client.create_collection(collection_name, schema=schema, index_params=index_params)

rng = np.random.default_rng(seed=19530)
rows = [
Expand Down
8 changes: 4 additions & 4 deletions examples/milvus_client/rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@

for role in [role_db_rw, role_db_ro]:
if role in current_roles:
privileges = milvus_client.describe_role(role)
for item in privileges:
role_info = milvus_client.describe_role(role)
for item in role_info['privileges']:
milvus_client.revoke_privilege(role, item["object_type"], item["privilege"], item["object_name"])

milvus_client.drop_role(role)
Expand All @@ -79,8 +79,8 @@
roles = milvus_client.list_roles()
print("roles:", roles)
for role in roles:
privileges = milvus_client.describe_role(role)
print(f"privileges for {role}:", privileges)
role_info = milvus_client.describe_role(role)
print(f"info for {role}:", role_info)


user1_info = milvus_client.describe_user("user1")
Expand Down
5 changes: 2 additions & 3 deletions pymilvus/bulk_writer/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,9 @@ def persist(self, local_path: str, **kwargs) -> list:
if row_count < 0:
row_count = len(self._buffer[k])
elif row_count != len(self._buffer[k]):
buffer_k_len = len(self._buffer[k])
self._throw(
"Column `{}` row count {} doesn't equal to the first column row count {}".format(
k, len(self._buffer[k]), row_count
)
f"Column {k} row count {buffer_k_len} doesn't equal to the first column row count {row_count}"
)

# output files
Expand Down
4 changes: 3 additions & 1 deletion pymilvus/milvus_client/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ def __eq__(self, other: None):


class IndexParams:
def __init__(self):
def __init__(self, field_name: str = "", **kwargs):
self._indexes = {}
if field_name:
self.add_index(field_name, **kwargs)

def add_index(self, field_name: str, index_type: str = "", index_name: str = "", **kwargs):
index_param = IndexParam(field_name, index_type, index_name, **kwargs)
Expand Down
15 changes: 5 additions & 10 deletions pymilvus/milvus_client/milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _fast_create_collection(
logger.error("Failed to create collection: %s", collection_name)
raise ex from ex

index_params = self.prepare_index_params()
index_params = IndexParams()
index_params.add_index(vector_field_name, "", "", metric_type=metric_type)
self.create_index(collection_name, index_params, timeout=timeout)
self.load_collection(collection_name, timeout=timeout)
Expand Down Expand Up @@ -278,6 +278,7 @@ def search(
search_params: Optional[dict] = None,
timeout: Optional[float] = None,
partition_names: Optional[List[str]] = None,
anns_field: Optional[str] = None,
**kwargs,
) -> List[List[dict]]:
"""Search for a query vector/vectors.
Expand Down Expand Up @@ -307,7 +308,7 @@ def search(
res = conn.search(
collection_name,
data,
"",
anns_field or "",
search_params or {},
expression=filter,
limit=limit,
Expand Down Expand Up @@ -593,8 +594,8 @@ def create_schema(cls, **kwargs):
return CollectionSchema([], **kwargs)

@classmethod
def prepare_index_params(cls):
return IndexParams()
def prepare_index_params(cls, field_name: str = "", **kwargs):
return IndexParams(field_name, **kwargs)

def _create_collection_with_schema(
self,
Expand All @@ -605,12 +606,6 @@ def _create_collection_with_schema(
**kwargs,
):
schema.verify()
if kwargs.get("auto_id", False):
schema.auto_id = True
schema.verify()

if schema.enable_dynamic_field is None:
schema.enable_dynamic_field = kwargs.get("enable_dynamic_field", False)

conn = self._get_connection()
if "consistency_level" not in kwargs:
Expand Down
8 changes: 2 additions & 6 deletions pymilvus/orm/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,10 @@ def auto_id(self, value: bool):

@property
def enable_dynamic_field(self):
return self._enable_dynamic_field
return bool(self._enable_dynamic_field)

@enable_dynamic_field.setter
def enable_dynamic_field(self, value: bool):
if value is None:
# we keep None here
self._enable_dynamic_field = value
return
self._enable_dynamic_field = bool(value)

def to_dict(self):
Expand All @@ -239,7 +235,7 @@ def to_dict(self):
"fields": [s.to_dict() for s in self._fields],
}
if self._enable_dynamic_field is not None:
_dict["enable_dynamic_field"] = self._enable_dynamic_field
_dict["enable_dynamic_field"] = self.enable_dynamic_field
return _dict

def verify(self):
Expand Down

0 comments on commit f4c3466

Please sign in to comment.