Skip to content

Commit

Permalink
Fix several bugs related to MilvusClient (#1920)
Browse files Browse the repository at this point in the history
Signed-off-by: zhenshan.cao <[email protected]>
  • Loading branch information
czs007 authored Feb 6, 2024
1 parent 0603224 commit ecf0ef3
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 21 deletions.
2 changes: 1 addition & 1 deletion examples/milvus_client/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

index_params = milvus_client.prepare_index_params()
index_params.add_index(field_name = "embeddings", metric_type="L2")
index_params.add_index(field_name = "title", index_type = "TRIE", index_name="my_trie")
index_params.add_index(field_name = "title", index_type = "Trie", index_name="my_trie")

print(fmt.format("Start create index"))
milvus_client.create_index(collection_name, index_params)
Expand Down
9 changes: 5 additions & 4 deletions pymilvus/milvus_client/index.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
class IndexParam:
def __init__(self, field_name: str, index_type: str, index_name: str, **kwargs):
self._field_name = field_name
self._index_name = index_name
self._index_type = index_type
self._index_name = index_name
self._kwargs = kwargs

@property
Expand All @@ -18,9 +18,10 @@ def index_type(self):
return self._index_type

def __iter__(self):
yield "field_name", self._field_name
yield "index_type", self._index_type
yield "index_name", self._index_name
yield "field_name", self.field_name
if self.index_type:
yield "index_type", self.index_type
yield "index_name", self.index_name
yield from self._kwargs.items()

def __str__(self):
Expand Down
42 changes: 26 additions & 16 deletions pymilvus/milvus_client/milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,7 @@ def _fast_create_collection(
raise ex from ex

index_params = self.prepare_index_params()
index_type = ""
index_name = ""
params = {"metric_type": metric_type}
index_params.add_index(vector_field_name, index_type, index_name, params=params)
index_params.add_index(vector_field_name, "", "", metric_type=metric_type)
self.create_index(collection_name, index_params, timeout=timeout)
self.load_collection(collection_name, timeout=timeout)

Expand All @@ -155,20 +152,16 @@ def _create_index(
):
conn = self._get_connection()
try:
params = index_param.get("params", {})
_index_type = index_param.get("index_type")
if _index_type:
params["index_type"] = _index_type
_metric_type = index_param.get("metric_type")
if _metric_type:
params["metric_type"] = _metric_type

params = index_param.pop("params", {})
field_name = index_param.pop("field_name", "")
index_name = index_param.pop("index_name", "")
params.update(index_param)
conn.create_index(
collection_name,
index_param["field_name"],
field_name,
params,
index_name=index_param.get("index_name", ""),
timeout=timeout,
index_name=index_name,
**kwargs,
)
logger.debug("Successfully created an index on collection: %s", collection_name)
Expand Down Expand Up @@ -499,11 +492,24 @@ def delete(
if isinstance(pks, (int, str)):
pks = [pks]

for pk in pks:
if not isinstance(pk, (int, str)):
msg = f"wrong type of argument pks, expect list, int or str, got '{type(pk).__name__}'"
raise TypeError(msg)

if ids is not None:
if isinstance(ids, (int, str)):
pks.append(ids)
elif isinstance(ids, list):
for id in ids:
if not isinstance(id, (int, str)):
msg = f"wrong type of argument ids, expect list, int or str, got '{type(id).__name__}'"
raise TypeError(msg)
pks.extend(ids)
else:
msg = f"wrong type of argument ids, expect list, int or str, got '{type(ids).__name__}'"
raise TypeError(msg)

expr = ""
conn = self._get_connection()
if pks:
Expand Down Expand Up @@ -549,7 +555,8 @@ def get_collection_stats(self, collection_name: str, timeout: Optional[float] =
conn = self._get_connection()
stats = conn.get_collection_stats(collection_name, timeout=timeout)
result = {stat.key: stat.value for stat in stats}
result["row_count"] = int(result["row_count"])
if "row_count" in result:
result["row_count"] = int(result["row_count"])
return result

def describe_collection(self, collection_name: str, timeout: Optional[float] = None, **kwargs):
Expand Down Expand Up @@ -818,7 +825,10 @@ def get_partition_stats(
msg = f"wrong type of argument 'partition_name', str expected, got '{type(partition_name).__name__}'"
raise TypeError(msg)
ret = conn.get_partition_stats(collection_name, partition_name, timeout=timeout, **kwargs)
return {stat.key: stat.value for stat in ret}
result = {stat.key: stat.value for stat in ret}
if "row_count" in result:
result["row_count"] = int(result["row_count"])
return

def create_user(self, user_name: str, password: str, timeout: Optional[float] = None, **kwargs):
conn = self._get_connection()
Expand Down

0 comments on commit ecf0ef3

Please sign in to comment.