Skip to content

Commit

Permalink
refine search iterator perf by radius probe(#26552) (#1656)
Browse files Browse the repository at this point in the history
Signed-off-by: MrPresent-Han <[email protected]>
  • Loading branch information
MrPresent-Han authored Aug 23, 2023
1 parent 1c08c0d commit 71b52fc
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 70 deletions.
2 changes: 2 additions & 0 deletions pymilvus/orm/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,6 @@
MAX_FILTERED_IDS_COUNT_ITERATION = 100000
INT64_MAX = 9223372036854775807
MAX_BATCH_SIZE: int = 16384
DEFAULT_SEARCH_EXTENSION_RATE: int = 10
UNLIMITED: int = -1
MAX_TRY_TIME: int = 10
260 changes: 190 additions & 70 deletions pymilvus/orm/iterator.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,26 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, TypeVar

from pymilvus.client.abstract import ChunkedQueryResult, LoopBase
from pymilvus.client.abstract import LoopBase
from pymilvus.exceptions import (
MilvusException,
ParamError,
)
from pymilvus.orm.search import Hits

from .connections import Connections
from .constants import (
BATCH_SIZE,
CALC_DIST_COSINE,
CALC_DIST_HAMMING,
CALC_DIST_IP,
CALC_DIST_JACCARD,
CALC_DIST_L2,
CALC_DIST_TANIMOTO,
DEFAULT_MAX_HAMMING_DISTANCE,
DEFAULT_MAX_JACCARD_DISTANCE,
DEFAULT_MAX_L2_DISTANCE,
DEFAULT_MAX_TANIMOTO_DISTANCE,
DEFAULT_MIN_COSINE_DISTANCE,
DEFAULT_MIN_IP_DISTANCE,
DEFAULT_SEARCH_EXTENSION_RATE,
FIELDS,
INT64_MAX,
ITERATION_EXTENSION_REDUCE_RATE,
MAX_BATCH_SIZE,
MAX_FILTERED_IDS_COUNT_ITERATION,
MAX_TRY_TIME,
METRIC_TYPE,
MILVUS_LIMIT,
OFFSET,
Expand All @@ -42,6 +36,10 @@
SearchIterator = TypeVar("SearchIterator")


def extend_batch_size(batch_size: int) -> int:
    """Return the enlarged request size derived from ``batch_size``.

    The batch size is multiplied by ``DEFAULT_SEARCH_EXTENSION_RATE`` and the
    product is capped at ``MAX_BATCH_SIZE``.
    """
    extended = batch_size * DEFAULT_SEARCH_EXTENSION_RATE
    return min(MAX_BATCH_SIZE, extended)


class QueryIterator:
def __init__(
self,
Expand Down Expand Up @@ -143,12 +141,12 @@ def next(self):
self.__maybe_cache(res)
ret = res[0 : min(self._kwargs[BATCH_SIZE], len(res))]

ret = self.check_reached_limit(ret)
ret = self.__check_reached_limit(ret)
self.__update_cursor(ret)
self._returned_count += len(ret)
return ret

def check_reached_limit(self, ret: List):
def __check_reached_limit(self, ret: List):
if self._limit == UNLIMITED:
return ret
left_count = self._limit - self._returned_count
Expand Down Expand Up @@ -193,42 +191,71 @@ def close(self) -> None:
iterator_cache.release_cache(self._cache_id_in_use)


def metrics_positive_related(metrics: str) -> bool:
    """Tell in which direction distances move across a result page.

    Args:
        metrics: the metric type name configured for the search.

    Returns:
        ``True`` for L2, where the iterator treats the last hit of a page as
        having the larger distance; ``False`` for IP and COSINE, where the
        last hit has the smaller one.

    Raises:
        MilvusException: if ``metrics`` is none of L2, IP or COSINE.
    """
    # Compare by value, not identity: ``is`` only succeeds when the caller's
    # string happens to be the very same object as the constant, which is not
    # guaranteed for strings built or deserialized at runtime.
    if metrics == CALC_DIST_L2:
        return True
    if metrics in (CALC_DIST_IP, CALC_DIST_COSINE):
        return False
    raise MilvusException(message=f"unsupported metrics type for search iteration{metrics}")


class SearchHit:
    """Plain record that pairs an entity id with its distance."""

    def __init__(self, id: Any, distance: Any):
        # ``id`` shadows the builtin, but the name is part of the signature.
        self._id, self._distance = id, distance


"""
class SearchHits:
def __init__(self, ids: list, distances: list):
self._ids = ids
self._distances = distances
def __next__(self):
return Hit()
"""


class SearchPage(LoopBase):
    """One page of search-iterator results.

    Search iteration only supports nq=1, so its response differs from the raw
    response of a search operation: a page holds a list of result sets and
    exposes them as one flat, indexable sequence of hits.
    """

    def __init__(self, res: Hits):
        """Create a page holding ``res``; ``None`` creates an empty page."""
        super().__init__()
        self._results = []
        if res is not None:
            self._results.append(res)

    def get_res(self):
        """Return the underlying list of result sets (not a copy)."""
        return self._results

    def __len__(self):
        """Return the total number of hits across all result sets."""
        return sum(len(res) for res in self._results)

    def get__item(self, idx: Any):
        """Return the hit at flat position ``idx``.

        Negative indices count from the end of the whole page, so ``-1`` is
        the last hit of the last result set.

        Returns:
            The hit, or ``None`` when the page holds no result set at all.

        Raises:
            IndexError: if ``idx`` falls outside the page.
        """
        if len(self._results) == 0:
            return None
        total = len(self)
        if idx < 0:
            # Normalize against the flat length; indexing the first result set
            # with a negative value would pick the wrong hit after a merge.
            idx += total
        if idx < 0 or idx >= total:
            msg = "Index out of range"
            raise IndexError(msg)
        offset = 0
        for res in self._results:
            if idx < offset + len(res):
                return res[idx - offset]
            offset += len(res)
        return None

    def merge(self, others: List[Hits]):
        """Append every result set in ``others`` to this page; ``None`` is a no-op."""
        if others is not None:
            self._results.extend(others)


class SearchIterator:
Expand Down Expand Up @@ -269,27 +296,39 @@ def __init__(
self._expr = expr
self.__check_set_params(param)
self._kwargs = kwargs
self._distance_cursor = [None]
self._filtered_ids = []
self._filtered_distance = None
self._schema = schema
self._limit = limit
self._returned_count = 0
self.__check_metrics()
self.__check_radius()
self.__seek()
self.__check_offset()
self.__check_rm_range_search_parameters()
self.__setup__pk_prop()
self.__init_search_iterator()

def check_reached_limit(self, ret: ChunkedQueryResult):
if self._limit == UNLIMITED:
return SearchPage(ret[0])
left_count = self._limit - self._returned_count
if left_count >= len(ret[0]):
return SearchPage(ret[0])
# has exceeded the limit, cut off the result and return
left_ret_arr = None
left_ret_arr = [] if left_count == 0 else ret[0][0:left_count]
return SearchPage(left_ret_arr)
def __init_search_iterator(self):
    """Run the first search and seed the iterator state from its page.

    The page is stored in ``iterator_cache`` and then used to set up the
    range parameters and the filtered-id list.

    Raises:
        MilvusException: if the first search returns no hits.
    """
    first_page = self.__execute_next_search(self._param, self._expr)
    if not len(first_page):
        raise MilvusException(
            message="Cannot init search iterator because there's no matched vectors returned"
        )
    self._cache_id = iterator_cache.cache(first_page, NO_CACHE_ID)
    self.__set_up_range_parameters(first_page)
    self.__update_filtered_ids(first_page)

def __set_up_range_parameters(self, page: SearchPage):
    """Derive the probing width and the starting tail band from ``page``.

    ``_width`` is the distance span between the first and last hit of the
    page, oriented so it is non-negative when the page is ordered from first
    to last hit. ``_tail_band`` is the distance of the last hit.
    """
    first_hit, last_hit = page[0], page[-1]
    if metrics_positive_related(self._param[METRIC_TYPE]):
        self._width = last_hit.distance - first_hit.distance
    else:
        self._width = first_hit.distance - last_hit.distance
    if self._width == 0:
        # Every hit in the page has the same distance. A zero width would make
        # radius equal to range_filter in every later probe, whatever the
        # coefficient, so fall back to a small nominal width.
        # NOTE(review): 0.05 is a heuristic; confirm it suits the distance scale.
        self._width = 0.05
    self._tail_band = last_hit.distance

def __check_reached_limit(self) -> bool:
    """Return True once a finite limit is set and that many hits were returned."""
    if self._limit == UNLIMITED:
        return False
    return self._returned_count >= self._limit

def __check_set_params(self, param: Dict):
if param is None:
Expand All @@ -308,32 +347,34 @@ def __setup__pk_prop(self):
self._pk_field_name = field["name"]
break
if self._pk_field_name is None or self._pk_field_name == "":
raise MilvusException(message="schema must contain pk field, broke")
raise ParamError(message="schema must contain pk field, broke")

def __check_metrics(self):
    """Validate that the search params name a metric type.

    Raises:
        ParamError: if the metric type is missing, ``None`` or empty.
    """
    # ``get`` turns a missing key into the same ParamError as an empty value
    # instead of leaking a bare KeyError to the caller.
    metric_type = self._param.get(METRIC_TYPE)
    if metric_type is None or metric_type == "":
        raise ParamError(message="must specify metrics type for search iterator")

def __check_radius(self):
if PARAMS not in self._param:
self._param[PARAMS] = {"radius": default_radius(self._param[METRIC_TYPE])}
elif RADIUS not in self._param[PARAMS]:
self._param[PARAMS][RADIUS] = default_radius(self._param[METRIC_TYPE])
"""we use search && range search to implement search iterator,
so range search parameters are disabled to clients"""

def __check_rm_range_search_parameters(self):
    """Drop any caller-supplied radius / range_filter from the search params.

    Does nothing when the params carry no ``PARAMS`` entry.
    """
    if PARAMS not in self._param:
        return
    inner_params = self._param[PARAMS]
    for key in (RADIUS, RANGE_FILTER):
        if key in inner_params:
            del inner_params[key]

def __check_offset(self):
    """Reject a non-zero offset, which search iteration does not support.

    Raises:
        ParamError: if the keyword arguments carry a non-zero offset.
    """
    offset = self._kwargs.get(OFFSET, 0)
    if offset != 0:
        raise ParamError(message="Not support offset when searching iteration")

def __update_cursor(self, res: SearchPage):
def __update_filtered_ids(self, res: SearchPage):
if len(res) == 0:
return
last_hit = res[-1]
if last_hit is None:
return
self._distance_cursor[0] = last_hit.distance
if self._distance_cursor[0] != self._filtered_distance:
if last_hit.distance != self._filtered_distance:
self._filtered_ids = [] # distance has changed, clear filter_ids array
self._filtered_distance = self._distance_cursor[0] # renew the distance for filtering
self._filtered_distance = last_hit.distance # renew the distance for filtering
for hit in res:
if hit.distance == last_hit.distance:
self._filtered_ids.append(hit.id)
Expand All @@ -344,15 +385,93 @@ def __update_cursor(self, res: SearchPage):
f"there is a danger of overly memory consumption"
)

def __is_cache_enough(self, count: int) -> bool:
    """Return True when the cached page exists and holds at least ``count`` hits."""
    cached = iterator_cache.fetch_cache(self._cache_id)
    return cached is not None and len(cached) >= count

def __extract_page_from_cache(self, count: int) -> SearchPage:
    """Take the first ``count`` hits out of the cached page.

    The remaining hits are written back to the cache under the same id.

    Returns:
        A new page holding the extracted hits.

    Raises:
        ParamError: if the cache is missing or holds fewer than ``count`` hits.
    """
    cached_page = iterator_cache.fetch_cache(self._cache_id)
    # Compute the length up front so building the error message cannot fail
    # with a TypeError when nothing is cached.
    cached_len = 0 if cached_page is None else len(cached_page)
    if cached_page is None or cached_len < count:
        raise ParamError(
            message=f"Wrong, try to extract {count} result from cache, "
            f"more than {cached_len} there must be sth wrong with code"
        )

    ret_page = SearchPage(cached_page[0:count])
    left_cache_page = SearchPage(cached_page[count:])
    iterator_cache.cache(left_cache_page, self._cache_id)
    return ret_page

def __push_new_page_to_cache(self, page: SearchPage) -> int:
    """Add ``page`` to the cached page and return the resulting hit count.

    When nothing is cached yet, ``page`` itself becomes the cached page;
    otherwise its result sets are merged into the page fetched from the cache.

    Raises:
        ParamError: if ``page`` is None.
    """
    if page is None:
        raise ParamError(message="Cannot push None page into cache")
    existing: SearchPage = iterator_cache.fetch_cache(self._cache_id)
    if existing is not None:
        existing.merge(page.get_res())
        return len(existing)
    iterator_cache.cache(page, self._cache_id)
    return len(page)

def next(self):
    """Return the next page of hits, or an empty page once the limit is reached.

    A page holds at most the configured batch size, further capped by what
    is left of the limit. Hits are served from the cache when it holds
    enough; otherwise new searches refill the cache first, and the page may
    come back shorter than requested if the refill falls short.
    """
    # 0. check reached limit
    if self.__check_reached_limit():
        return SearchPage(None)
    ret_len = self._iterator_params[BATCH_SIZE]
    # Value comparison, matching __check_reached_limit; identity on an int
    # only works through an interpreter caching detail.
    if self._limit != UNLIMITED:
        left_len = self._limit - self._returned_count
        ret_len = min(left_len, ret_len)

    # 1. if cached page is sufficient, directly return
    if self.__is_cache_enough(ret_len):
        ret_page = self.__extract_page_from_cache(ret_len)
        self._returned_count += len(ret_page)
        return ret_page

    # 2. if cached page not enough, try to fill the result by probing with
    # constant width until the batch is filled or the retry budget runs out
    new_page = self.__try_search_fill()
    cached_page_len = self.__push_new_page_to_cache(new_page)
    ret_len = min(cached_page_len, ret_len)
    ret_page = self.__extract_page_from_cache(ret_len)

    # 3. account for what is handed back to the caller
    self._returned_count += ret_len
    return ret_page

def __try_search_fill(self) -> SearchPage:
    """Probe successive distance rings until a batch is collected.

    Each round searches the ring that starts at the current tail band. A
    non-empty ring is merged into the result and moves the tail band to its
    last hit. An empty ring widens the next probe by raising the coefficient.

    Returns:
        The merged page. It may hold fewer hits than the batch size when the
        retry budget (``MAX_TRY_TIME``) runs out first.
    """
    final_page = SearchPage(None)
    try_time = 0
    coefficient = 1
    batch_size = self._iterator_params[BATCH_SIZE]
    while True:
        next_params = self.__next_params(coefficient)
        next_expr = self.__filtered_duplicated_result_expr(self._expr)
        new_page = self.__execute_next_search(next_params, next_expr)
        self.__update_filtered_ids(new_page)
        try_time += 1
        if len(new_page) > 0:
            final_page.merge(new_page.get_res())
            self._tail_band = new_page[-1].distance
            # if the current ring contains vectors, we always set coefficient back to 1
            coefficient = 1
        else:
            # if there's a ring containing no vectors matched, then we need to extend
            # the ring continually to avoid empty ring problem
            coefficient += 1
        # ``>=``: an exactly full batch is already enough, so do not spend
        # another search on it.
        if len(final_page) >= batch_size or try_time > MAX_TRY_TIME:
            break
    return final_page

def __execute_next_search(self, next_params: dict, next_expr: str) -> SearchPage:
res = self._conn.search(
self._iterator_params["collection_name"],
self._iterator_params["data"],
self._iterator_params["ann_field"],
next_params,
self._iterator_params[BATCH_SIZE],
extend_batch_size(self._iterator_params[BATCH_SIZE]),
next_expr,
self._iterator_params["partition_names"],
self._iterator_params["output_fields"],
Expand All @@ -361,10 +480,7 @@ def next(self):
schema=self._schema,
**self._kwargs,
)
res = self.check_reached_limit(res)
self.__update_cursor(res)
self._returned_count += len(res)
return res
return SearchPage(res[0])

# At present the range_filter bound is inclusive ('greater/less than or equal'),
# so vectors lying exactly on the boundary distance could be returned again in later pages
Expand All @@ -388,14 +504,18 @@ def __filtered_duplicated_result_expr(self, expr: str):
return f"{self._pk_field_name} not in [{filtered_ids_str}]"
return expr

def __next_params(self, coefficient: int = 1):
    """Build the search params for the next ring probe.

    Args:
        coefficient: multiplier applied to the ring width; values below 1
            are treated as 1.

    Returns:
        A copy of the base params whose ``PARAMS`` entry carries
        ``RANGE_FILTER`` set to the current tail band and ``RADIUS`` set one
        scaled width beyond it, in the direction given by the metric.
    """
    coefficient = max(1, coefficient)
    # Deep copy: a shallow dict.copy() shares the nested PARAMS dict, so the
    # writes below would leak radius/range_filter into self._param.
    next_params = deepcopy(self._param)
    # The base params are not guaranteed to carry a PARAMS entry.
    range_params = next_params.setdefault(PARAMS, {})
    step = self._width * coefficient
    if metrics_positive_related(self._param[METRIC_TYPE]):
        range_params[RADIUS] = self._tail_band + step
    else:
        range_params[RADIUS] = self._tail_band - step
    range_params[RANGE_FILTER] = self._tail_band
    return next_params

def close(self):
    """Release this iterator's entry in ``iterator_cache``."""
    cache_id = self._cache_id
    iterator_cache.release_cache(cache_id)


class IteratorCache:
Expand Down

0 comments on commit 71b52fc

Please sign in to comment.