Skip to content

Commit

Permalink
Batch cherry-pick from master (#1253)
Browse files Browse the repository at this point in the history
- Add get server version (#1240)
- Remove round when parsing query result (#1244)
- Use already exists schema_dict (#1249)
- Add timeout for all APIs in utitlity (#1250)

See also: milvus-io/milvus#20870, #1248

Signed-off-by: yangxuan <[email protected]>

Signed-off-by: yangxuan <[email protected]>
  • Loading branch information
XuanYang-cn authored Dec 12, 2022
1 parent b151592 commit 635654e
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 50 deletions.
5 changes: 3 additions & 2 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc

import numpy as np
from .configs import DefaultConfigs
from .types import DataType
from .constants import DEFAULT_CONSISTENCY_LEVEL
Expand Down Expand Up @@ -277,7 +278,7 @@ def get__item(self, item):
entity_row_data[field_data.field_name] = field_data.scalars.long_data.data[item]
elif field_data.type == DataType.FLOAT:
if len(field_data.scalars.float_data.data) >= item:
entity_row_data[field_data.field_name] = round(field_data.scalars.float_data.data[item], 6)
entity_row_data[field_data.field_name] = np.single(field_data.scalars.float_data.data[item])
elif field_data.type == DataType.DOUBLE:
if len(field_data.scalars.double_data.data) >= item:
entity_row_data[field_data.field_name] = field_data.scalars.double_data.data[item]
Expand All @@ -292,7 +293,7 @@ def get__item(self, item):
if len(field_data.vectors.float_vector.data) >= item * dim:
start_pos = item * dim
end_pos = item * dim + dim
entity_row_data[field_data.field_name] = [round(x, 6) for x in
entity_row_data[field_data.field_name] = [np.single(x) for x in
field_data.vectors.float_vector.data[
start_pos:end_pos]]
elif field_data.type == DataType.BINARY_VECTOR:
Expand Down
18 changes: 14 additions & 4 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from urllib import parse

import grpc
import numpy as np
from grpc._cython import cygrpc

from ..grpc_gen import milvus_pb2_grpc
Expand Down Expand Up @@ -189,11 +190,11 @@ def server_address(self):
""" Server network address """
return self._address

def reset_password(self, user, old_password, new_password):
def reset_password(self, user, old_password, new_password, timeout=None):
"""
reset password and then setup the grpc channel.
"""
self.update_password(user, old_password, new_password)
self.update_password(user, old_password, new_password, timeout=timeout)
self._setup_authorization_interceptor(user, new_password)
self._setup_grpc_channel()

Expand Down Expand Up @@ -906,7 +907,7 @@ def query(self, collection_name, expr, output_fields=None, partition_names=None,
elif field_data.type == DataType.INT64:
result[field_data.field_name] = field_data.scalars.long_data.data[index]
elif field_data.type == DataType.FLOAT:
result[field_data.field_name] = round(field_data.scalars.float_data.data[index], 6)
result[field_data.field_name] = np.single(field_data.scalars.float_data.data[index])
elif field_data.type == DataType.DOUBLE:
result[field_data.field_name] = field_data.scalars.double_data.data[index]
elif field_data.type == DataType.VARCHAR:
Expand All @@ -918,7 +919,7 @@ def query(self, collection_name, expr, output_fields=None, partition_names=None,
dim = field_data.vectors.dim
start_pos = index * dim
end_pos = index * dim + dim
result[field_data.field_name] = [round(x, 6) for x in
result[field_data.field_name] = [np.single(x) for x in
field_data.vectors.float_vector.data[start_pos:end_pos]]
elif field_data.type == DataType.BINARY_VECTOR:
dim = field_data.vectors.dim
Expand Down Expand Up @@ -1174,3 +1175,12 @@ def select_grant_for_role_and_object(self, role_name, object, object_name, timeo
raise MilvusException(resp.status.error_code, resp.status.reason)

return GrantInfo(resp.entities)

@retry_on_rpc_failure()
def get_server_version(self, timeout=None, **kwargs) -> str:
req = Prepare.get_server_version()
resp = self._stub.GetVersion(req, timeout=timeout)
if resp.status.error_code != 0:
raise MilvusException(resp.status.error_code, resp.status.reason)

return resp.version
4 changes: 4 additions & 0 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,3 +800,7 @@ def select_grant_request(cls, role_name, object, object_name):
entity=milvus_types.GrantEntity(role=milvus_types.RoleEntity(name=role_name),
object=milvus_types.ObjectEntity(name=object) if object else None,
object_name=object_name if object_name else None))

@classmethod
def get_server_version(cls):
return milvus_types.GetVersionRequest()
18 changes: 8 additions & 10 deletions pymilvus/orm/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,12 +419,11 @@ def insert(self, data: [List, pandas.DataFrame], partition_name: str=None, timeo
if data is None:
return MutationResult(data)
check_insert_data_schema(self._schema, data)
entities = Prepare.prepare_insert_data(data, self._schema)

conn = self._get_connection()
entities = Prepare.prepare_insert_data(data, self._schema)
schema_dict = self._schema.to_dict()
schema_dict["consistency_level"] = self._consistency_level
res = conn.batch_insert(self._name, entities, partition_name, timeout=timeout, schema=schema_dict, **kwargs)
res = conn.batch_insert(self._name, entities, partition_name,
timeout=timeout, schema=self._schema_dict, **kwargs)

if kwargs.get("_async", False):
return MutationFuture(res)
Expand Down Expand Up @@ -609,7 +608,7 @@ def search(self, data, anns_field, param, limit, expr=None, partition_names=None
conn = self._get_connection()
res = conn.search(self._name, data, anns_field, param, limit, expr,
partition_names, output_fields, round_decimal, timeout=timeout,
collection_schema=self._schema_dict, **kwargs)
schema=self._schema_dict, **kwargs)
if kwargs.get("_async", False):
return SearchFuture(res)
return SearchResult(res)
Expand Down Expand Up @@ -691,9 +690,8 @@ def query(self, expr, output_fields=None, partition_names=None, timeout=None, **
raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr))

conn = self._get_connection()
schema = self._schema.to_dict()
schema["consistency_level"] = self._consistency_level
res = conn.query(self._name, expr, output_fields, partition_names, timeout=timeout, schema=schema, **kwargs)
res = conn.query(self._name, expr, output_fields, partition_names,
timeout=timeout, schema=self._schema_dict, **kwargs)
return res

@property
Expand Down Expand Up @@ -1083,6 +1081,6 @@ def get_replicas(self, timeout=None, **kwargs) -> Replica:
conn = self._get_connection()
return conn.get_replicas(self.name, timeout=timeout, **kwargs)

def describe(self):
def describe(self, timeout=None):
conn = self._get_connection()
return conn.describe_collection(self.name)
return conn.describe_collection(self.name, timeout=timeout)
17 changes: 7 additions & 10 deletions pymilvus/orm/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def __init__(self, collection, name, description="", **kwargs):
if not has:
conn.create_partition(self._collection.name, self._name, **copy_kwargs)

self._schema_dict = self._schema.to_dict()
self._schema_dict["consistency_level"] = self._consistency_level

def __repr__(self):
return json.dumps({
'name': self.name,
Expand Down Expand Up @@ -294,10 +297,8 @@ def insert(self, data, timeout=None, **kwargs):
raise PartitionNotExistException(message=ExceptionsMessage.PartitionNotExist)
# TODO: check insert data schema here?
entities = Prepare.prepare_insert_data(data, self._collection.schema)
schema_dict = self._schema.to_dict()
schema_dict["consistency_level"] = self._consistency_level
res = conn.batch_insert(self._collection.name, entities=entities,
partition_name=self._name, timeout=timeout, orm=True, schema=schema_dict, **kwargs)
res = conn.batch_insert(self._collection.name, entities=entities, partition_name=self._name,
timeout=timeout, orm=True, schema=self._schema_dict, **kwargs)
if kwargs.get("_async", False):
return MutationFuture(res)
return MutationResult(res)
Expand Down Expand Up @@ -476,10 +477,8 @@ def search(self, data, anns_field, param, limit,
- Top1 hit id: 8, distance: 0.10143111646175385, score: 0.10143111646175385
"""
conn = self._get_connection()
schema_dict = self._schema.to_dict()
schema_dict["consistency_level"] = self._consistency_level
res = conn.search(self._collection.name, data, anns_field, param, limit, expr, [self._name], output_fields,
round_decimal=round_decimal, timeout=timeout, collection_schema=schema_dict, **kwargs)
round_decimal=round_decimal, timeout=timeout, schema=self._schema_dict, **kwargs)
if kwargs.get("_async", False):
return SearchFuture(res)
return SearchResult(res)
Expand Down Expand Up @@ -552,10 +551,8 @@ def query(self, expr, output_fields=None, timeout=None, **kwargs):
- Query results: [{'film_id': 0, 'film_date': 2000}, {'film_id': 1, 'film_date': 2001}]
"""
conn = self._get_connection()
schema_dict = self._schema.to_dict()
schema_dict["consistency_level"] = self._consistency_level
res = conn.query(self._collection.name, expr, output_fields, [self._name],
timeout=timeout, schema=schema_dict, **kwargs)
timeout=timeout, schema=self._schema_dict, **kwargs)
return res

def get_replicas(self, timeout=None, **kwargs) -> Replica:
Expand Down
63 changes: 39 additions & 24 deletions pymilvus/orm/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _get_connection(alias):
return connections._fetch_handler(alias)


def loading_progress(collection_name, partition_names=None, using="default"):
def loading_progress(collection_name, partition_names=None, using="default", timeout=None):
""" Show loading progress of sealed segments in percentage.
:param collection_name: The name of collection is loading
Expand Down Expand Up @@ -198,7 +198,7 @@ def loading_progress(collection_name, partition_names=None, using="default"):
>>> utility.loading_progress("test_loading_progress")
{'loading_progress': '100%'}
"""
progress = _get_connection(using).get_loading_progress(collection_name, partition_names)
progress = _get_connection(using).get_loading_progress(collection_name, partition_names, timeout=timeout)
return {
"loading_progress": f"{progress:.0f}%",
}
Expand Down Expand Up @@ -241,7 +241,7 @@ def wait_for_loading_complete(collection_name, partition_names=None, timeout=Non
return _get_connection(using).wait_for_loading_partitions(collection_name, partition_names, timeout=timeout)


def index_building_progress(collection_name, index_name="", using="default"):
def index_building_progress(collection_name, index_name="", using="default", timeout=None):
"""
Show # indexed entities vs. # total entities.
Expand Down Expand Up @@ -283,7 +283,8 @@ def index_building_progress(collection_name, index_name="", using="default"):
>>> index = c.create_index(field_name="float_vector", index_params=index_params, index_name="ivf_flat")
>>> utility.index_building_progress("test_collection", c.name)
"""
return _get_connection(using).get_index_build_progress(collection_name=collection_name, index_name=index_name)
return _get_connection(using).get_index_build_progress(
collection_name=collection_name, index_name=index_name, timeout=timeout)


def wait_for_index_building_complete(collection_name, index_name="", timeout=None, using="default"):
Expand Down Expand Up @@ -332,7 +333,7 @@ def wait_for_index_building_complete(collection_name, index_name="", timeout=Non
return _get_connection(using).wait_for_creating_index(collection_name, index_name, timeout=timeout)[0]


def has_collection(collection_name, using="default"):
def has_collection(collection_name, using="default", timeout=None):
"""
Checks whether a specified collection exists.
Expand All @@ -352,10 +353,10 @@ def has_collection(collection_name, using="default"):
>>> collection = Collection(name="test_collection", schema=schema)
>>> utility.has_collection("test_collection")
"""
return _get_connection(using).has_collection(collection_name)
return _get_connection(using).has_collection(collection_name, timeout=timeout)


def has_partition(collection_name, partition_name, using="default"):
def has_partition(collection_name, partition_name, using="default", timeout=None):
"""
Checks if a specified partition exists in a collection.
Expand All @@ -378,7 +379,7 @@ def has_partition(collection_name, partition_name, using="default"):
>>> collection = Collection(name="test_collection", schema=schema)
>>> utility.has_partition("_default")
"""
return _get_connection(using).has_partition(collection_name, partition_name)
return _get_connection(using).has_partition(collection_name, partition_name, timeout=timeout)


def drop_collection(collection_name, timeout=None, using="default"):
Expand Down Expand Up @@ -725,7 +726,7 @@ def list_bulk_insert_tasks(limit=0, collection_name=None, timeout=None, using="d
return _get_connection(using).list_bulk_insert_tasks(limit, collection_name, timeout=timeout, **kwargs)


def reset_password(user: str, old_password: str, new_password: str, using="default"):
def reset_password(user: str, old_password: str, new_password: str, using="default", timeout=None):
"""
Reset the user & password of the connection.
You must provide the original password to check if the operation is valid.
Expand All @@ -745,10 +746,10 @@ def reset_password(user: str, old_password: str, new_password: str, using="defau
>>> users = utility.list_usernames()
>>> print(f"users in Milvus: {users}")
"""
return _get_connection(using).reset_password(user, old_password, new_password)
return _get_connection(using).reset_password(user, old_password, new_password, timeout=timeout)


def create_user(user: str, password: str, using="default"):
def create_user(user: str, password: str, using="default", timeout=None):
""" Create User using the given user and password.
:param user: the user name.
:type user: str
Expand All @@ -763,10 +764,10 @@ def create_user(user: str, password: str, using="default"):
>>> users = utility.list_usernames()
>>> print(f"users in Milvus: {users}")
"""
return _get_connection(using).create_user(user, password)
return _get_connection(using).create_user(user, password, timeout=timeout)


def update_password(user: str, old_password, new_password: str, using="default"):
def update_password(user: str, old_password, new_password: str, using="default", timeout=None):
"""
Update user password using the given user and password.
You must provide the original password to check if the operation is valid.
Expand All @@ -788,10 +789,10 @@ def update_password(user: str, old_password, new_password: str, using="default")
>>> users = utility.list_usernames()
>>> print(f"users in Milvus: {users}")
"""
return _get_connection(using).update_password(user, old_password, new_password)
return _get_connection(using).update_password(user, old_password, new_password, timeout=timeout)


def delete_user(user: str, using="default"):
def delete_user(user: str, using="default", timeout=None):
""" Delete User corresponding to the username.
:param user: the user name.
:type user: str
Expand All @@ -803,10 +804,10 @@ def delete_user(user: str, using="default"):
>>> users = utility.list_usernames()
>>> print(f"users in Milvus: {users}")
"""
return _get_connection(using).delete_user(user)
return _get_connection(using).delete_user(user, timeout=timeout)


def list_usernames(using="default"):
def list_usernames(using="default", timeout=None):
""" List all usernames.
:return list of str:
The usernames in Milvus instances.
Expand All @@ -817,10 +818,10 @@ def list_usernames(using="default"):
>>> users = utility.list_usernames()
>>> print(f"users in Milvus: {users}")
"""
return _get_connection(using).list_usernames()
return _get_connection(using).list_usernames(timeout=timeout)


def list_roles(include_user_info: bool, using="default"):
def list_roles(include_user_info: bool, using="default", timeout=None):
""" List All Role Info
:param include_user_info: whether to obtain the user information associated with roles
:type include_user_info: bool
Expand All @@ -832,10 +833,10 @@ def list_roles(include_user_info: bool, using="default"):
>>> roles = utility.list_roles()
>>> print(f"roles in Milvus: {roles}")
"""
return _get_connection(using).select_all_role(include_user_info)
return _get_connection(using).select_all_role(include_user_info, timeout=timeout)


def list_user(username: str, include_role_info: bool, using="default"):
def list_user(username: str, include_role_info: bool, using="default", timeout=None):
""" List One User Info
:param username: user name.
:type username: str
Expand All @@ -849,10 +850,10 @@ def list_user(username: str, include_role_info: bool, using="default"):
>>> user = utility.list_user(username, include_role_info)
>>> print(f"user info: {user}")
"""
return _get_connection(using).select_one_user(username, include_role_info)
return _get_connection(using).select_one_user(username, include_role_info, timeout=timeout)


def list_users(include_role_info: bool, using="default"):
def list_users(include_role_info: bool, using="default", timeout=None):
""" List All User Info
:param include_role_info: whether to obtain the role information associated with users
:type include_role_info: bool
Expand All @@ -864,4 +865,18 @@ def list_users(include_role_info: bool, using="default"):
>>> users = utility.list_users(include_role_info)
>>> print(f"users info: {users}")
"""
return _get_connection(using).select_all_user(include_role_info)
return _get_connection(using).select_all_user(include_role_info, timeout=timeout)

def get_server_version(using="default", timeout=None) -> str:
""" get the running server's version
:returns: server's version
:rtype: str
:example:
>>> from pymilvus import connections, utility
>>> connections.connect()
>>> utility.get_server_version()
>>> "2.2.0"
"""
return _get_connection(using).get_server_version(timeout=timeout)
Loading

0 comments on commit 635654e

Please sign in to comment.