diff --git a/pymilvus/orm/constants.py b/pymilvus/orm/constants.py index df50856a8..e85dcfc3a 100644 --- a/pymilvus/orm/constants.py +++ b/pymilvus/orm/constants.py @@ -36,6 +36,7 @@ RANGE_FILTER = "range_filter" FIELDS = "fields" ITERATION_EXTENSION_REDUCE_RATE = "iteration_extension_reduce_rate" +EF = "ef" DEFAULT_MAX_L2_DISTANCE = 99999999.0 DEFAULT_MIN_IP_DISTANCE = -99999999.0 DEFAULT_MAX_HAMMING_DISTANCE = 99999999.0 diff --git a/pymilvus/orm/iterator.py b/pymilvus/orm/iterator.py index 3671dd150..4c83424e7 100644 --- a/pymilvus/orm/iterator.py +++ b/pymilvus/orm/iterator.py @@ -18,6 +18,7 @@ CALC_DIST_L2, CALC_DIST_TANIMOTO, DEFAULT_SEARCH_EXTENSION_RATE, + EF, FIELDS, INT64_MAX, ITERATION_EXTENSION_REDUCE_RATE, @@ -39,7 +40,11 @@ SearchIterator = TypeVar("SearchIterator") -def extend_batch_size(batch_size: int) -> int: +def extend_batch_size(batch_size: int, next_param: dict) -> int: + if EF in next_param[PARAMS]: + return min( + MAX_BATCH_SIZE, batch_size * DEFAULT_SEARCH_EXTENSION_RATE, next_param[PARAMS][EF] + ) return min(MAX_BATCH_SIZE, batch_size * DEFAULT_SEARCH_EXTENSION_RATE) @@ -294,6 +299,7 @@ def __init__( } self._expr = expr self.__check_set_params(param) + self.__check_for_special_index_param() self._kwargs = kwargs self._filtered_ids = [] self._filtered_distance = None @@ -337,6 +343,15 @@ def __check_set_params(self, param: Dict): if PARAMS not in self._param: self._param[PARAMS] = {} + def __check_for_special_index_param(self): + if ( + EF in self._param[PARAMS] + and self._param[PARAMS][EF] < self._iterator_params[BATCH_SIZE] + ): + raise MilvusException( + message="When using hnsw index, provided ef must be larger than or equal to batch size" + ) + def __setup__pk_prop(self): fields = self._schema[FIELDS] for field in fields: @@ -472,7 +487,7 @@ def __execute_next_search(self, next_params: dict, next_expr: str) -> SearchPage self._iterator_params["data"], self._iterator_params["ann_field"], next_params, - extend_batch_size(self._iterator_params[BATCH_SIZE]), + extend_batch_size(self._iterator_params[BATCH_SIZE], next_params), next_expr, self._iterator_params["partition_names"], self._iterator_params["output_fields"],