From f4c34668aba286ba15913311c2b96f18a13fa1e7 Mon Sep 17 00:00:00 2001 From: "zhenshan.cao" Date: Fri, 8 Mar 2024 10:56:54 +0800 Subject: [PATCH] Avoiding parameter conflicts and adding compatibility logic for prepare_index_params (#1954) Signed-off-by: zhenshan.cao --- .../milvus_client/customize_schema_auto_id.py | 70 +++++++++++++++++++ .../{ => milvus_client}/non_ascii_encode.py | 4 +- examples/milvus_client/rbac.py | 8 +-- pymilvus/bulk_writer/buffer.py | 5 +- pymilvus/milvus_client/index.py | 4 +- pymilvus/milvus_client/milvus_client.py | 15 ++-- pymilvus/orm/schema.py | 8 +-- 7 files changed, 88 insertions(+), 26 deletions(-) create mode 100644 examples/milvus_client/customize_schema_auto_id.py rename examples/{ => milvus_client}/non_ascii_encode.py (87%) diff --git a/examples/milvus_client/customize_schema_auto_id.py b/examples/milvus_client/customize_schema_auto_id.py new file mode 100644 index 000000000..4d727b9aa --- /dev/null +++ b/examples/milvus_client/customize_schema_auto_id.py @@ -0,0 +1,70 @@ +import time +import numpy as np +from pymilvus import ( + MilvusClient, + DataType +) + +fmt = "\n=== {:30} ===\n" +dim = 8 +collection_name = "hello_milvus" +milvus_client = MilvusClient("http://localhost:19530") + +has_collection = milvus_client.has_collection(collection_name, timeout=5) +if has_collection: + milvus_client.drop_collection(collection_name) + +schema = milvus_client.create_schema(enable_dynamic_field=True, auto_id=True) +schema.add_field("id", DataType.INT64, is_primary=True) +schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim) +schema.add_field("title", DataType.VARCHAR, max_length=64) + + +index_params = milvus_client.prepare_index_params() +index_params.add_index(field_name = "embeddings", metric_type="L2") +milvus_client.create_collection(collection_name, schema=schema, index_params=index_params, consistency_level="Strong") + +print(fmt.format(" all collections ")) +print(milvus_client.list_collections()) + +print(fmt.format(f"schema of collection {collection_name}")) +print(milvus_client.describe_collection(collection_name)) + +rng = np.random.default_rng(seed=19530) +rows = [ + {"embeddings": rng.random((1, dim))[0], "a": 100, "title": "t1"}, + {"embeddings": rng.random((1, dim))[0], "b": 200, "title": "t2"}, + {"embeddings": rng.random((1, dim))[0], "c": 300, "title": "t3"}, + {"embeddings": rng.random((1, dim))[0], "d": 400, "title": "t4"}, + {"embeddings": rng.random((1, dim))[0], "e": 500, "title": "t5"}, + {"embeddings": rng.random((1, dim))[0], "f": 600, "title": "t6"}, +] + +print(fmt.format("Start inserting entities")) +insert_result = milvus_client.insert(collection_name, rows) +print(fmt.format("Inserting entities done")) +print(insert_result) + + +print(fmt.format("Start load collection ")) +milvus_client.load_collection(collection_name) + +print(fmt.format("Start query by specifying primary keys")) +query_results = milvus_client.query(collection_name, ids=insert_result['ids'][0]) +print(query_results[0]) + +print(fmt.format("Start query by specifying filtering expression")) +query_results = milvus_client.query(collection_name, filter= "f == 600 or title == 't2'") +for ret in query_results: + print(ret) + +rng = np.random.default_rng(seed=19530) +vectors_to_search = rng.random((1, dim)) + +print(fmt.format(f"Start search with retrieve serveral fields.")) +result = milvus_client.search(collection_name, vectors_to_search, limit=3, output_fields=["pk", "a", "b"]) +for hits in result: + for hit in hits: + print(f"hit: {hit}") + +milvus_client.drop_collection(collection_name) diff --git a/examples/non_ascii_encode.py b/examples/milvus_client/non_ascii_encode.py similarity index 87% rename from examples/non_ascii_encode.py rename to examples/milvus_client/non_ascii_encode.py index c96f440f1..a8ee76b06 100644 --- a/examples/non_ascii_encode.py +++ b/examples/milvus_client/non_ascii_encode.py @@ -11,8 +11,8 @@ schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dimension) schema.add_field("info", DataType.JSON) -index_param = client.prepare_index_params("embeddings", metric_type="L2") -client.create_collection_with_schema(collection_name, schema, index_param) +index_params = client.prepare_index_params("embeddings", metric_type="L2") +client.create_collection(collection_name, schema=schema, index_params=index_params) rng = np.random.default_rng(seed=19530) rows = [ diff --git a/examples/milvus_client/rbac.py b/examples/milvus_client/rbac.py index f7ecc1057..63ca5d005 100644 --- a/examples/milvus_client/rbac.py +++ b/examples/milvus_client/rbac.py @@ -59,8 +59,8 @@ for role in [role_db_rw, role_db_ro]: if role in current_roles: - privileges = milvus_client.describe_role(role) - for item in privileges: + role_info = milvus_client.describe_role(role) + for item in role_info['privileges']: milvus_client.revoke_privilege(role, item["object_type"], item["privilege"], item["object_name"]) milvus_client.drop_role(role) @@ -79,8 +79,8 @@ roles = milvus_client.list_roles() print("roles:", roles) for role in roles: - privileges = milvus_client.describe_role(role) - print(f"privileges for {role}:", privileges) + role_info = milvus_client.describe_role(role) + print(f"info for {role}:", role_info) user1_info = milvus_client.describe_user("user1") diff --git a/pymilvus/bulk_writer/buffer.py b/pymilvus/bulk_writer/buffer.py index 405616551..12c971d4e 100644 --- a/pymilvus/bulk_writer/buffer.py +++ b/pymilvus/bulk_writer/buffer.py @@ -108,10 +108,9 @@ def persist(self, local_path: str, **kwargs) -> list: if row_count < 0: row_count = len(self._buffer[k]) elif row_count != len(self._buffer[k]): + buffer_k_len = len(self._buffer[k]) self._throw( - "Column `{}` row count {} doesn't equal to the first column row count {}".format( - k, len(self._buffer[k]), row_count - ) + f"Column {k} row count {buffer_k_len} doesn't equal to the first column row count {row_count}" ) # output files diff --git a/pymilvus/milvus_client/index.py b/pymilvus/milvus_client/index.py index 94213bbeb..5c3a3f4ea 100644 --- a/pymilvus/milvus_client/index.py +++ b/pymilvus/milvus_client/index.py @@ -37,8 +37,10 @@ def __eq__(self, other: None): class IndexParams: - def __init__(self): + def __init__(self, field_name: str = "", **kwargs): self._indexes = {} + if field_name: + self.add_index(field_name, **kwargs) def add_index(self, field_name: str, index_type: str = "", index_name: str = "", **kwargs): index_param = IndexParam(field_name, index_type, index_name, **kwargs) diff --git a/pymilvus/milvus_client/milvus_client.py b/pymilvus/milvus_client/milvus_client.py index 145c33c7d..edb02eea8 100644 --- a/pymilvus/milvus_client/milvus_client.py +++ b/pymilvus/milvus_client/milvus_client.py @@ -132,7 +132,7 @@ def _fast_create_collection( logger.error("Failed to create collection: %s", collection_name) raise ex from ex - index_params = self.prepare_index_params() + index_params = IndexParams() index_params.add_index(vector_field_name, "", "", metric_type=metric_type) self.create_index(collection_name, index_params, timeout=timeout) self.load_collection(collection_name, timeout=timeout) @@ -278,6 +278,7 @@ def search( search_params: Optional[dict] = None, timeout: Optional[float] = None, partition_names: Optional[List[str]] = None, + anns_field: Optional[str] = None, **kwargs, ) -> List[List[dict]]: """Search for a query vector/vectors. @@ -307,7 +308,7 @@ def search( res = conn.search( collection_name, data, - "", + anns_field or "", search_params or {}, expression=filter, limit=limit, @@ -593,8 +594,8 @@ def create_schema(cls, **kwargs): return CollectionSchema([], **kwargs) @classmethod - def prepare_index_params(cls): - return IndexParams() + def prepare_index_params(cls, field_name: str = "", **kwargs): + return IndexParams(field_name, **kwargs) def _create_collection_with_schema( self, @@ -605,12 +606,6 @@ def _create_collection_with_schema( **kwargs, ): schema.verify() - if kwargs.get("auto_id", False): - schema.auto_id = True - schema.verify() - - if schema.enable_dynamic_field is None: - schema.enable_dynamic_field = kwargs.get("enable_dynamic_field", False) conn = self._get_connection() if "consistency_level" not in kwargs: diff --git a/pymilvus/orm/schema.py b/pymilvus/orm/schema.py index 5cbc0be9b..bbfcacedd 100644 --- a/pymilvus/orm/schema.py +++ b/pymilvus/orm/schema.py @@ -222,14 +222,10 @@ def auto_id(self, value: bool): @property def enable_dynamic_field(self): - return self._enable_dynamic_field + return bool(self._enable_dynamic_field) @enable_dynamic_field.setter def enable_dynamic_field(self, value: bool): - if value is None: - # we keep None here - self._enable_dynamic_field = value - return self._enable_dynamic_field = bool(value) def to_dict(self): @@ -239,7 +235,7 @@ def to_dict(self): "fields": [s.to_dict() for s in self._fields], } if self._enable_dynamic_field is not None: - _dict["enable_dynamic_field"] = self._enable_dynamic_field + _dict["enable_dynamic_field"] = self.enable_dynamic_field return _dict def verify(self):