diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py index f2e0dd0e1..b83f9cdfb 100644 --- a/pymilvus/client/abstract.py +++ b/pymilvus/client/abstract.py @@ -1,5 +1,6 @@ import abc +import numpy as np from .configs import DefaultConfigs from .types import DataType from .constants import DEFAULT_CONSISTENCY_LEVEL @@ -277,7 +278,7 @@ def get__item(self, item): entity_row_data[field_data.field_name] = field_data.scalars.long_data.data[item] elif field_data.type == DataType.FLOAT: if len(field_data.scalars.float_data.data) >= item: - entity_row_data[field_data.field_name] = round(field_data.scalars.float_data.data[item], 6) + entity_row_data[field_data.field_name] = np.single(field_data.scalars.float_data.data[item]) elif field_data.type == DataType.DOUBLE: if len(field_data.scalars.double_data.data) >= item: entity_row_data[field_data.field_name] = field_data.scalars.double_data.data[item] @@ -292,7 +293,7 @@ def get__item(self, item): if len(field_data.vectors.float_vector.data) >= item * dim: start_pos = item * dim end_pos = item * dim + dim - entity_row_data[field_data.field_name] = [round(x, 6) for x in + entity_row_data[field_data.field_name] = [np.single(x) for x in field_data.vectors.float_vector.data[ start_pos:end_pos]] elif field_data.type == DataType.BINARY_VECTOR: diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py index d0e05d1c5..0671c2653 100644 --- a/pymilvus/client/grpc_handler.py +++ b/pymilvus/client/grpc_handler.py @@ -5,6 +5,7 @@ from urllib import parse import grpc +import numpy as np from grpc._cython import cygrpc from ..grpc_gen import milvus_pb2_grpc @@ -189,11 +190,11 @@ def server_address(self): """ Server network address """ return self._address - def reset_password(self, user, old_password, new_password): + def reset_password(self, user, old_password, new_password, timeout=None): """ reset password and then setup the grpc channel. """ - self.update_password(user, old_password, new_password) + self.update_password(user, old_password, new_password, timeout=timeout) self._setup_authorization_interceptor(user, new_password) self._setup_grpc_channel() @@ -906,7 +907,7 @@ def query(self, collection_name, expr, output_fields=None, partition_names=None, elif field_data.type == DataType.INT64: result[field_data.field_name] = field_data.scalars.long_data.data[index] elif field_data.type == DataType.FLOAT: - result[field_data.field_name] = round(field_data.scalars.float_data.data[index], 6) + result[field_data.field_name] = np.single(field_data.scalars.float_data.data[index]) elif field_data.type == DataType.DOUBLE: result[field_data.field_name] = field_data.scalars.double_data.data[index] elif field_data.type == DataType.VARCHAR: @@ -918,7 +919,7 @@ def query(self, collection_name, expr, output_fields=None, partition_names=None, dim = field_data.vectors.dim start_pos = index * dim end_pos = index * dim + dim - result[field_data.field_name] = [round(x, 6) for x in + result[field_data.field_name] = [np.single(x) for x in field_data.vectors.float_vector.data[start_pos:end_pos]] elif field_data.type == DataType.BINARY_VECTOR: dim = field_data.vectors.dim @@ -1174,3 +1175,12 @@ def select_grant_for_role_and_object(self, role_name, object, object_name, timeo raise MilvusException(resp.status.error_code, resp.status.reason) return GrantInfo(resp.entities) + + @retry_on_rpc_failure() + def get_server_version(self, timeout=None, **kwargs) -> str: + req = Prepare.get_server_version() + resp = self._stub.GetVersion(req, timeout=timeout) + if resp.status.error_code != 0: + raise MilvusException(resp.status.error_code, resp.status.reason) + + return resp.version diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index 0adfd62e0..20b5d292d 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -800,3 +800,7 @@ def select_grant_request(cls, role_name, object, object_name): entity=milvus_types.GrantEntity(role=milvus_types.RoleEntity(name=role_name), object=milvus_types.ObjectEntity(name=object) if object else None, object_name=object_name if object_name else None)) + + @classmethod + def get_server_version(cls): + return milvus_types.GetVersionRequest() diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py index 1fcad03af..fdf3b18a8 100644 --- a/pymilvus/orm/collection.py +++ b/pymilvus/orm/collection.py @@ -419,12 +419,11 @@ def insert(self, data: [List, pandas.DataFrame], partition_name: str=None, timeo if data is None: return MutationResult(data) check_insert_data_schema(self._schema, data) + entities = Prepare.prepare_insert_data(data, self._schema) conn = self._get_connection() - entities = Prepare.prepare_insert_data(data, self._schema) - schema_dict = self._schema.to_dict() - schema_dict["consistency_level"] = self._consistency_level - res = conn.batch_insert(self._name, entities, partition_name, timeout=timeout, schema=schema_dict, **kwargs) + res = conn.batch_insert(self._name, entities, partition_name, + timeout=timeout, schema=self._schema_dict, **kwargs) if kwargs.get("_async", False): return MutationFuture(res) @@ -609,7 +608,7 @@ def search(self, data, anns_field, param, limit, expr=None, partition_names=None conn = self._get_connection() res = conn.search(self._name, data, anns_field, param, limit, expr, partition_names, output_fields, round_decimal, timeout=timeout, - collection_schema=self._schema_dict, **kwargs) + schema=self._schema_dict, **kwargs) if kwargs.get("_async", False): return SearchFuture(res) return SearchResult(res) @@ -691,9 +690,8 @@ def query(self, expr, output_fields=None, partition_names=None, timeout=None, ** raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(expr)) conn = self._get_connection() - schema = self._schema.to_dict() - schema["consistency_level"] = self._consistency_level - res = conn.query(self._name, expr, output_fields, partition_names, timeout=timeout, schema=schema, **kwargs) + res = conn.query(self._name, expr, output_fields, partition_names, + timeout=timeout, schema=self._schema_dict, **kwargs) return res @property @@ -1083,6 +1081,6 @@ def get_replicas(self, timeout=None, **kwargs) -> Replica: conn = self._get_connection() return conn.get_replicas(self.name, timeout=timeout, **kwargs) - def describe(self): + def describe(self, timeout=None): conn = self._get_connection() - return conn.describe_collection(self.name) + return conn.describe_collection(self.name, timeout=timeout) diff --git a/pymilvus/orm/partition.py b/pymilvus/orm/partition.py index 71c58a40b..73a3cff3e 100644 --- a/pymilvus/orm/partition.py +++ b/pymilvus/orm/partition.py @@ -49,6 +49,9 @@ def __init__(self, collection, name, description="", **kwargs): if not has: conn.create_partition(self._collection.name, self._name, **copy_kwargs) + self._schema_dict = self._schema.to_dict() + self._schema_dict["consistency_level"] = self._consistency_level + def __repr__(self): return json.dumps({ 'name': self.name, @@ -294,10 +297,8 @@ def insert(self, data, timeout=None, **kwargs): raise PartitionNotExistException(message=ExceptionsMessage.PartitionNotExist) # TODO: check insert data schema here? entities = Prepare.prepare_insert_data(data, self._collection.schema) - schema_dict = self._schema.to_dict() - schema_dict["consistency_level"] = self._consistency_level - res = conn.batch_insert(self._collection.name, entities=entities, - partition_name=self._name, timeout=timeout, orm=True, schema=schema_dict, **kwargs) + res = conn.batch_insert(self._collection.name, entities=entities, partition_name=self._name, + timeout=timeout, orm=True, schema=self._schema_dict, **kwargs) if kwargs.get("_async", False): return MutationFuture(res) return MutationResult(res) @@ -476,10 +477,8 @@ def search(self, data, anns_field, param, limit, - Top1 hit id: 8, distance: 0.10143111646175385, score: 0.10143111646175385 """ conn = self._get_connection() - schema_dict = self._schema.to_dict() - schema_dict["consistency_level"] = self._consistency_level res = conn.search(self._collection.name, data, anns_field, param, limit, expr, [self._name], output_fields, - round_decimal=round_decimal, timeout=timeout, collection_schema=schema_dict, **kwargs) + round_decimal=round_decimal, timeout=timeout, schema=self._schema_dict, **kwargs) if kwargs.get("_async", False): return SearchFuture(res) return SearchResult(res) @@ -552,10 +551,8 @@ def query(self, expr, output_fields=None, timeout=None, **kwargs): - Query results: [{'film_id': 0, 'film_date': 2000}, {'film_id': 1, 'film_date': 2001}] """ conn = self._get_connection() - schema_dict = self._schema.to_dict() - schema_dict["consistency_level"] = self._consistency_level res = conn.query(self._collection.name, expr, output_fields, [self._name], - timeout=timeout, schema=schema_dict, **kwargs) + timeout=timeout, schema=self._schema_dict, **kwargs) return res def get_replicas(self, timeout=None, **kwargs) -> Replica: diff --git a/pymilvus/orm/utility.py b/pymilvus/orm/utility.py index 543769fbf..8f7b10631 100644 --- a/pymilvus/orm/utility.py +++ b/pymilvus/orm/utility.py @@ -165,7 +165,7 @@ def _get_connection(alias): return connections._fetch_handler(alias) -def loading_progress(collection_name, partition_names=None, using="default"): +def loading_progress(collection_name, partition_names=None, using="default", timeout=None): """ Show loading progress of sealed segments in percentage. :param collection_name: The name of collection is loading @@ -198,7 +198,7 @@ def loading_progress(collection_name, partition_names=None, using="default"): >>> utility.loading_progress("test_loading_progress") {'loading_progress': '100%'} """ - progress = _get_connection(using).get_loading_progress(collection_name, partition_names) + progress = _get_connection(using).get_loading_progress(collection_name, partition_names, timeout=timeout) return { "loading_progress": f"{progress:.0f}%", } @@ -241,7 +241,7 @@ def wait_for_loading_complete(collection_name, partition_names=None, timeout=Non return _get_connection(using).wait_for_loading_partitions(collection_name, partition_names, timeout=timeout) -def index_building_progress(collection_name, index_name="", using="default"): +def index_building_progress(collection_name, index_name="", using="default", timeout=None): """ Show # indexed entities vs. # total entities. @@ -283,7 +283,8 @@ def index_building_progress(collection_name, index_name="", using="default"): >>> index = c.create_index(field_name="float_vector", index_params=index_params, index_name="ivf_flat") >>> utility.index_building_progress("test_collection", c.name) """ - return _get_connection(using).get_index_build_progress(collection_name=collection_name, index_name=index_name) + return _get_connection(using).get_index_build_progress( + collection_name=collection_name, index_name=index_name, timeout=timeout) def wait_for_index_building_complete(collection_name, index_name="", timeout=None, using="default"): @@ -332,7 +333,7 @@ def wait_for_index_building_complete(collection_name, index_name="", timeout=Non return _get_connection(using).wait_for_creating_index(collection_name, index_name, timeout=timeout)[0] -def has_collection(collection_name, using="default"): +def has_collection(collection_name, using="default", timeout=None): """ Checks whether a specified collection exists. @@ -352,10 +353,10 @@ def has_collection(collection_name, using="default"): >>> collection = Collection(name="test_collection", schema=schema) >>> utility.has_collection("test_collection") """ - return _get_connection(using).has_collection(collection_name) + return _get_connection(using).has_collection(collection_name, timeout=timeout) -def has_partition(collection_name, partition_name, using="default"): +def has_partition(collection_name, partition_name, using="default", timeout=None): """ Checks if a specified partition exists in a collection. @@ -378,7 +379,7 @@ def has_partition(collection_name, partition_name, using="default"): >>> collection = Collection(name="test_collection", schema=schema) >>> utility.has_partition("_default") """ - return _get_connection(using).has_partition(collection_name, partition_name) + return _get_connection(using).has_partition(collection_name, partition_name, timeout=timeout) def drop_collection(collection_name, timeout=None, using="default"): @@ -725,7 +726,7 @@ def list_bulk_insert_tasks(limit=0, collection_name=None, timeout=None, using="d return _get_connection(using).list_bulk_insert_tasks(limit, collection_name, timeout=timeout, **kwargs) -def reset_password(user: str, old_password: str, new_password: str, using="default"): +def reset_password(user: str, old_password: str, new_password: str, using="default", timeout=None): """ Reset the user & password of the connection. You must provide the original password to check if the operation is valid. @@ -745,10 +746,10 @@ def reset_password(user: str, old_password: str, new_password: str, using="defau >>> users = utility.list_usernames() >>> print(f"users in Milvus: {users}") """ - return _get_connection(using).reset_password(user, old_password, new_password) + return _get_connection(using).reset_password(user, old_password, new_password, timeout=timeout) -def create_user(user: str, password: str, using="default"): +def create_user(user: str, password: str, using="default", timeout=None): """ Create User using the given user and password. :param user: the user name. :type user: str @@ -763,10 +764,10 @@ def create_user(user: str, password: str, using="default"): >>> users = utility.list_usernames() >>> print(f"users in Milvus: {users}") """ - return _get_connection(using).create_user(user, password) + return _get_connection(using).create_user(user, password, timeout=timeout) -def update_password(user: str, old_password, new_password: str, using="default"): +def update_password(user: str, old_password, new_password: str, using="default", timeout=None): """ Update user password using the given user and password. You must provide the original password to check if the operation is valid. @@ -788,10 +789,10 @@ def update_password(user: str, old_password, new_password: str, using="default") >>> users = utility.list_usernames() >>> print(f"users in Milvus: {users}") """ - return _get_connection(using).update_password(user, old_password, new_password) + return _get_connection(using).update_password(user, old_password, new_password, timeout=timeout) -def delete_user(user: str, using="default"): +def delete_user(user: str, using="default", timeout=None): """ Delete User corresponding to the username. :param user: the user name. :type user: str @@ -803,10 +804,10 @@ def delete_user(user: str, using="default"): >>> users = utility.list_usernames() >>> print(f"users in Milvus: {users}") """ - return _get_connection(using).delete_user(user) + return _get_connection(using).delete_user(user, timeout=timeout) -def list_usernames(using="default"): +def list_usernames(using="default", timeout=None): """ List all usernames. :return list of str: The usernames in Milvus instances. @@ -817,10 +818,10 @@ def list_usernames(using="default"): >>> users = utility.list_usernames() >>> print(f"users in Milvus: {users}") """ - return _get_connection(using).list_usernames() + return _get_connection(using).list_usernames(timeout=timeout) -def list_roles(include_user_info: bool, using="default"): +def list_roles(include_user_info: bool, using="default", timeout=None): """ List All Role Info :param include_user_info: whether to obtain the user information associated with roles :type include_user_info: bool @@ -832,10 +833,10 @@ def list_roles(include_user_info: bool, using="default"): >>> roles = utility.list_roles() >>> print(f"roles in Milvus: {roles}") """ - return _get_connection(using).select_all_role(include_user_info) + return _get_connection(using).select_all_role(include_user_info, timeout=timeout) -def list_user(username: str, include_role_info: bool, using="default"): +def list_user(username: str, include_role_info: bool, using="default", timeout=None): """ List One User Info :param username: user name. :type username: str @@ -849,10 +850,10 @@ def list_user(username: str, include_role_info: bool, using="default"): >>> user = utility.list_user(username, include_role_info) >>> print(f"user info: {user}") """ - return _get_connection(using).select_one_user(username, include_role_info) + return _get_connection(using).select_one_user(username, include_role_info, timeout=timeout) -def list_users(include_role_info: bool, using="default"): +def list_users(include_role_info: bool, using="default", timeout=None): """ List All User Info :param include_role_info: whether to obtain the role information associated with users :type include_role_info: bool @@ -864,4 +865,18 @@ def list_users(include_role_info: bool, using="default"): >>> users = utility.list_users(include_role_info) >>> print(f"users info: {users}") """ - return _get_connection(using).select_all_user(include_role_info) + return _get_connection(using).select_all_user(include_role_info, timeout=timeout) + +def get_server_version(using="default", timeout=None) -> str: + """ get the running server's version + + :returns: server's version + :rtype: str + + :example: + >>> from pymilvus import connections, utility + >>> connections.connect() + >>> utility.get_server_version() + >>> "2.2.0" + """ + return _get_connection(using).get_server_version(timeout=timeout) diff --git a/tests/test_grpc_handler.py b/tests/test_grpc_handler.py index 233a5ccff..d19837a6c 100644 --- a/tests/test_grpc_handler.py +++ b/tests/test_grpc_handler.py @@ -70,3 +70,43 @@ def test_has_collection_Unavailable_exception(self, channel, client_thread): with pytest.raises(MilvusUnavailableException): has_collection_future.result() + + def test_get_server_version_error(self, channel, client_thread): + handler = GrpcHandler(channel=channel) + + get_version_future = client_thread.submit( + handler.get_server_version) + + (invocation_metadata, request, rpc) = ( + channel.take_unary_unary(descriptor.methods_by_name['GetVersion'])) + rpc.send_initial_metadata(()) + + expected_result = milvus_pb2.GetVersionResponse( + status=common_pb2.Status( + error_code=common_pb2.UnexpectedError, + reason="unexpected error"), + ) + rpc.terminate(expected_result, (), grpc.StatusCode.OK, '') + + with pytest.raises(MilvusException): + get_version_future.result() + + def test_get_server_version_error(self, channel, client_thread): + version = "2.2.0" + handler = GrpcHandler(channel=channel) + + get_version_future = client_thread.submit( + handler.get_server_version) + + (invocation_metadata, request, rpc) = ( + channel.take_unary_unary(descriptor.methods_by_name['GetVersion'])) + rpc.send_initial_metadata(()) + + expected_result = milvus_pb2.GetVersionResponse( + status=common_pb2.Status(error_code=common_pb2.Success), + version=version, + ) + rpc.terminate(expected_result, (), grpc.StatusCode.OK, '') + + got_result = get_version_future.result() + assert got_result == version