diff --git a/pymilvus/orm/constants.py b/pymilvus/orm/constants.py index 8d87f39d0..df50856a8 100644 --- a/pymilvus/orm/constants.py +++ b/pymilvus/orm/constants.py @@ -45,4 +45,6 @@ MAX_FILTERED_IDS_COUNT_ITERATION = 100000 INT64_MAX = 9223372036854775807 MAX_BATCH_SIZE: int = 16384 +DEFAULT_SEARCH_EXTENSION_RATE: int = 10 UNLIMITED: int = -1 +MAX_TRY_TIME: int = 10 diff --git a/pymilvus/orm/iterator.py b/pymilvus/orm/iterator.py index 3d91c9b42..7b55b4610 100644 --- a/pymilvus/orm/iterator.py +++ b/pymilvus/orm/iterator.py @@ -1,32 +1,26 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, TypeVar -from pymilvus.client.abstract import ChunkedQueryResult, LoopBase +from pymilvus.client.abstract import LoopBase from pymilvus.exceptions import ( MilvusException, ParamError, ) +from pymilvus.orm.search import Hits from .connections import Connections from .constants import ( BATCH_SIZE, CALC_DIST_COSINE, - CALC_DIST_HAMMING, CALC_DIST_IP, - CALC_DIST_JACCARD, CALC_DIST_L2, - CALC_DIST_TANIMOTO, - DEFAULT_MAX_HAMMING_DISTANCE, - DEFAULT_MAX_JACCARD_DISTANCE, - DEFAULT_MAX_L2_DISTANCE, - DEFAULT_MAX_TANIMOTO_DISTANCE, - DEFAULT_MIN_COSINE_DISTANCE, - DEFAULT_MIN_IP_DISTANCE, + DEFAULT_SEARCH_EXTENSION_RATE, FIELDS, INT64_MAX, ITERATION_EXTENSION_REDUCE_RATE, MAX_BATCH_SIZE, MAX_FILTERED_IDS_COUNT_ITERATION, + MAX_TRY_TIME, METRIC_TYPE, MILVUS_LIMIT, OFFSET, @@ -42,6 +36,10 @@ SearchIterator = TypeVar("SearchIterator") +def extend_batch_size(batch_size: int) -> int: + return min(MAX_BATCH_SIZE, batch_size * DEFAULT_SEARCH_EXTENSION_RATE) + + class QueryIterator: def __init__( self, @@ -143,12 +141,12 @@ def next(self): self.__maybe_cache(res) ret = res[0 : min(self._kwargs[BATCH_SIZE], len(res))] - ret = self.check_reached_limit(ret) + ret = self.__check_reached_limit(ret) self.__update_cursor(ret) self._returned_count += len(ret) return ret - def check_reached_limit(self, ret: List): + def __check_reached_limit(self, ret: List): if self._limit == UNLIMITED: return ret left_count = self._limit - self._returned_count @@ -193,42 +191,71 @@ def close(self) -> None: iterator_cache.release_cache(self._cache_id_in_use) -def default_radius(metrics: str): +def metrics_positive_related(metrics: str) -> bool: if metrics is CALC_DIST_L2: - return DEFAULT_MAX_L2_DISTANCE - if metrics is CALC_DIST_IP: - return DEFAULT_MIN_IP_DISTANCE - if metrics is CALC_DIST_HAMMING: - return DEFAULT_MAX_HAMMING_DISTANCE - if metrics is CALC_DIST_TANIMOTO: - return DEFAULT_MAX_TANIMOTO_DISTANCE - if metrics is CALC_DIST_JACCARD: - return DEFAULT_MAX_JACCARD_DISTANCE - if metrics is CALC_DIST_COSINE: - return DEFAULT_MIN_COSINE_DISTANCE - raise MilvusException(message="unknown metrics type for search iteration") + return True + if metrics is CALC_DIST_IP or metrics is CALC_DIST_COSINE: + return False + raise MilvusException(message=f"unsupported metrics type for search iteration{metrics}") + + +class SearchHit: + def __init__(self, id: Any, distance: Any): + self._id = id + self._distance = distance + + +""" +class SearchHits: + + def __init__(self, ids: list, distances: list): + self._ids = ids + self._distances = distances + + def __next__(self): + return Hit() +""" class SearchPage(LoopBase): """Since we only support nq=1 in search iteration, so search iteration response should be different from raw response of search operation""" - def __init__(self, res: List): + def __init__(self, res: Hits): super().__init__() - self._res = res + self._results = [] + if res is not None: + self._results.append(res) def get_res(self): - return self._res + return self._results def __len__(self): - if self._res is not None: - return len(self._res) - return 0 + length = 0 + for res in self._results: + length += len(res) + return length def get__item(self, idx: Any): - if self._res is None: + if len(self._results) == 0: return None - return self._res[idx] + if idx >= self.__len__(): + msg = "Index out of range" + raise IndexError(msg) + index = 0 + ret = None + for res in self._results: + if index + len(res) <= idx: + index += len(res) + else: + ret = res[idx - index] + break + return ret + + def merge(self, others: List[Hits]): + if others is not None: + for other in others: + self._results.append(other) class SearchIterator: @@ -269,27 +296,39 @@ def __init__( self._expr = expr self.__check_set_params(param) self._kwargs = kwargs - self._distance_cursor = [None] self._filtered_ids = [] self._filtered_distance = None self._schema = schema self._limit = limit self._returned_count = 0 self.__check_metrics() - self.__check_radius() - self.__seek() + self.__check_offset() + self.__check_rm_range_search_parameters() self.__setup__pk_prop() + self.__init_search_iterator() - def check_reached_limit(self, ret: ChunkedQueryResult): - if self._limit == UNLIMITED: - return SearchPage(ret[0]) - left_count = self._limit - self._returned_count - if left_count >= len(ret[0]): - return SearchPage(ret[0]) - # has exceeded the limit, cut off the result and return - left_ret_arr = None - left_ret_arr = [] if left_count == 0 else ret[0][0:left_count] - return SearchPage(left_ret_arr) + def __init_search_iterator(self): + init_page = self.__execute_next_search(self._param, self._expr) + if len(init_page) == 0: + raise MilvusException( + message="Cannot init search iterator because there's no matched vectors returned" + ) + self._cache_id = iterator_cache.cache(init_page, NO_CACHE_ID) + self.__set_up_range_parameters(init_page) + self.__update_filtered_ids(init_page) + + def __set_up_range_parameters(self, page: SearchPage): + first_hit, last_hit = page[0], page[-1] + if metrics_positive_related(self._param[METRIC_TYPE]): + self._width = last_hit.distance - first_hit.distance + else: + self._width = first_hit.distance - last_hit.distance + self._tail_band = last_hit.distance + + def __check_reached_limit(self) -> bool: + if self._limit == UNLIMITED or self._returned_count < self._limit: + return False + return True def __check_set_params(self, param: Dict): if param is None: @@ -308,32 +347,34 @@ def __setup__pk_prop(self): self._pk_field_name = field["name"] break if self._pk_field_name is None or self._pk_field_name == "": - raise MilvusException(message="schema must contain pk field, broke") + raise ParamError(message="schema must contain pk field, broke") def __check_metrics(self): if self._param[METRIC_TYPE] is None or self._param[METRIC_TYPE] == "": - raise MilvusException(message="must specify metrics type for search iterator") + raise ParamError(message="must specify metrics type for search iterator") - def __check_radius(self): - if PARAMS not in self._param: - self._param[PARAMS] = {"radius": default_radius(self._param[METRIC_TYPE])} - elif RADIUS not in self._param[PARAMS]: - self._param[PARAMS][RADIUS] = default_radius(self._param[METRIC_TYPE]) + """we use search && range search to implement search iterator, + so range search parameters are disabled to clients""" - def __seek(self): + def __check_rm_range_search_parameters(self): + if PARAMS in self._param and RADIUS in self._param[PARAMS]: + del self._param[PARAMS][RADIUS] + if PARAMS in self._param and RANGE_FILTER in self._param[PARAMS]: + del self._param[PARAMS][RANGE_FILTER] + + def __check_offset(self): if self._kwargs.get(OFFSET, 0) != 0: - raise MilvusException(message="Not support offset when searching iteration") + raise ParamError(message="Not support offset when searching iteration") - def __update_cursor(self, res: SearchPage): + def __update_filtered_ids(self, res: SearchPage): if len(res) == 0: return last_hit = res[-1] if last_hit is None: return - self._distance_cursor[0] = last_hit.distance - if self._distance_cursor[0] != self._filtered_distance: + if last_hit.distance != self._filtered_distance: self._filtered_ids = [] # distance has changed, clear filter_ids array - self._filtered_distance = self._distance_cursor[0] # renew the distance for filtering + self._filtered_distance = last_hit.distance # renew the distance for filtering for hit in res: if hit.distance == last_hit.distance: self._filtered_ids.append(hit.id) @@ -344,15 +385,93 @@ def __update_cursor(self, res: SearchPage): f"there is a danger of overly memory consumption" ) + def __is_cache_enough(self, count: int) -> bool: + cached_page = iterator_cache.fetch_cache(self._cache_id) + if cached_page is None or len(cached_page) < count: + return False + return True + + def __extract_page_from_cache(self, count: int) -> SearchPage: + cached_page = iterator_cache.fetch_cache(self._cache_id) + if cached_page is None or len(cached_page) < count: + raise ParamError( + message=f"Wrong, try to extract {count} result from cache, " + f"more than {len(cached_page)} there must be sth wrong with code" + ) + + ret_page_res = cached_page[0:count] + ret_page = SearchPage(ret_page_res) + left_cache_page = SearchPage(cached_page[count:]) + iterator_cache.cache(left_cache_page, self._cache_id) + return ret_page + + def __push_new_page_to_cache(self, page: SearchPage) -> int: + if page is None: + raise ParamError(message="Cannot push None page into cache") + cached_page: SearchPage = iterator_cache.fetch_cache(self._cache_id) + if cached_page is None: + iterator_cache.cache(page, self._cache_id) + cached_page = page + else: + cached_page.merge(page.get_res()) + return len(cached_page) + def next(self): - next_params = self.__next_params() - next_expr = self.__filtered_duplicated_result_expr(self._expr) + # 0. check reached limit + if self.__check_reached_limit(): + return SearchPage(None) + ret_len = self._iterator_params[BATCH_SIZE] + if self._limit is not UNLIMITED: + left_len = self._limit - self._returned_count + ret_len = min(left_len, ret_len) + + # 1. if cached page is sufficient, directly return + if self.__is_cache_enough(ret_len): + ret_page = self.__extract_page_from_cache(ret_len) + self._returned_count += len(ret_page) + return ret_page + + # 2. if cached page not enough, try to fill the result by probing with constant width + # until finish filling or exceed max trial time: 10 + new_page = self.__try_search_fill() + cached_page_len = self.__push_new_page_to_cache(new_page) + ret_len = min(cached_page_len, ret_len) + ret_page = self.__extract_page_from_cache(ret_len) + + # 3. update filter ids to avoid returning result repeatedly + self._returned_count += ret_len + return ret_page + + def __try_search_fill(self) -> SearchPage: + final_page = SearchPage(None) + try_time = 0 + coefficient = 1 + while True: + next_params = self.__next_params(coefficient) + next_expr = self.__filtered_duplicated_result_expr(self._expr) + new_page = self.__execute_next_search(next_params, next_expr) + self.__update_filtered_ids(new_page) + try_time += 1 + if len(new_page) > 0: + final_page.merge(new_page.get_res()) + self._tail_band = new_page[-1].distance + # if the current ring contains vectors, we always set coefficient back to 1 + coefficient = 1 + else: + # if there's a ring containing no vectors matched, then we need to extend + # the ring continually to avoid empty ring problem + coefficient += 1 + if len(final_page) > self._iterator_params[BATCH_SIZE] or try_time > MAX_TRY_TIME: + break + return final_page + + def __execute_next_search(self, next_params: dict, next_expr: str) -> SearchPage: res = self._conn.search( self._iterator_params["collection_name"], self._iterator_params["data"], self._iterator_params["ann_field"], next_params, - self._iterator_params[BATCH_SIZE], + extend_batch_size(self._iterator_params[BATCH_SIZE]), next_expr, self._iterator_params["partition_names"], self._iterator_params["output_fields"], @@ -361,10 +480,7 @@ def next(self): schema=self._schema, **self._kwargs, ) - res = self.check_reached_limit(res) - self.__update_cursor(res) - self._returned_count += len(res) - return res + return SearchPage(res[0]) # at present, the range_filter parameter means 'larger/less and equal', # so there would be vectors with same distances returned multiple times in different pages @@ -388,14 +504,18 @@ def __filtered_duplicated_result_expr(self, expr: str): return f"{self._pk_field_name} not in [{filtered_ids_str}]" return expr - def __next_params(self): + def __next_params(self, coefficient: int): + coefficient = max(1, coefficient) next_params = self._param.copy() - if self._distance_cursor[0] is not None: - next_params[PARAMS][RANGE_FILTER] = self._distance_cursor[0] + if metrics_positive_related(self._param[METRIC_TYPE]): + next_params[PARAMS][RADIUS] = self._tail_band + self._width * coefficient + else: + next_params[PARAMS][RADIUS] = self._tail_band - self._width * coefficient + next_params[PARAMS][RANGE_FILTER] = self._tail_band return next_params def close(self): - pass + iterator_cache.release_cache(self._cache_id) class IteratorCache: