Skip to content

Commit

Permalink
support mvcc and break-down-continue for iterator(#2278) (#2279)
Browse files Browse the repository at this point in the history
related: #2278

Signed-off-by: MrPresent-Han <[email protected]>
Co-authored-by: MrPresent-Han <[email protected]>
  • Loading branch information
MrPresent-Han and MrPresent-Han authored Oct 9, 2024
1 parent 7784050 commit f068b1a
Show file tree
Hide file tree
Showing 12 changed files with 440 additions and 325 deletions.
17 changes: 9 additions & 8 deletions examples/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@
logging.getLogger().addHandler(console_handler) # Attach the handler to the root logger



def re_create_collection(skip_data_period: bool):
if not skip_data_period:
def re_create_collection(prepare_new_data: bool):
if prepare_new_data:
if utility.has_collection(COLLECTION_NAME) and CLEAR_EXIST:
utility.drop_collection(COLLECTION_NAME)
print(f"dropped existed collection{COLLECTION_NAME}")
Expand Down Expand Up @@ -118,7 +117,8 @@ def query_iterate_collection_no_offset(collection):

query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE],
offset=0, batch_size=5, consistency_level=CONSISTENCY_LEVEL,
reduce_stop_for_best="false", print_iterator_cursor=True)
reduce_stop_for_best="false", print_iterator_cursor=False,
iterator_cp_file="/tmp/it_cp")
no_best_ids: set = set({})
page_idx = 0
while True:
Expand All @@ -136,7 +136,8 @@ def query_iterate_collection_no_offset(collection):
print("best---------------------------")
query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE],
offset=0, batch_size=5, consistency_level=CONSISTENCY_LEVEL,
reduce_stop_for_best="true", print_iterator_cursor=True)
reduce_stop_for_best="true", print_iterator_cursor=False, iterator_cp_file="/tmp/it_cp")

best_ids: set = set({})
page_idx = 0
while True:
Expand Down Expand Up @@ -239,10 +240,10 @@ def search_iterator_collection_with_limit(collection):


def main():
skip_data_period = True
prepare_new_data = True
connections.connect("default", host=HOST, port=PORT)
collection = re_create_collection(skip_data_period)
if not skip_data_period:
collection = re_create_collection(prepare_new_data)
if prepare_new_data:
collection = prepare_data(collection)
query_iterate_collection_no_offset(collection)
query_iterate_collection_with_offset(collection)
Expand Down
6 changes: 5 additions & 1 deletion pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ def __init__(
res: schema_pb2.SearchResultData,
round_decimal: Optional[int] = None,
status: Optional[common_pb2.Status] = None,
session_ts: Optional[int] = 0,
):
self._nq = res.num_queries
all_topks = res.topks
Expand Down Expand Up @@ -441,9 +442,12 @@ def __init__(
Hits(topk, all_pks[start:end], all_scores[start:end], nq_th_fields, output_fields)
)
nq_thres += topk

self._session_ts = session_ts
super().__init__(data)

def get_session_ts(self) -> Optional[int]:
    """Return the ``session_ts`` value that was passed to the constructor."""
    return self._session_ts

def get_fields_by_range(
self, start: int, end: int, all_fields_data: List[schema_pb2.FieldData]
) -> Dict[str, Tuple[List[Any], schema_pb2.FieldData]]:
Expand Down
3 changes: 3 additions & 0 deletions pymilvus/client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
RANK_GROUP_SCORER = "rank_group_scorer"
GROUP_STRICT_SIZE = "group_strict_size"
ITERATOR_FIELD = "iterator"
ITERATOR_SESSION_TS_FIELD = "iterator_session_ts"
PAGE_RETAIN_ORDER_FIELD = "page_retain_order"

RANKER_TYPE_RRF = "rrf"
RANKER_TYPE_WEIGHTED = "weighted"

GUARANTEE_TIMESTAMP = "guarantee_timestamp"
14 changes: 11 additions & 3 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
is_legal_host,
is_legal_port,
)
from .constants import ITERATOR_SESSION_TS_FIELD
from .prepare import Prepare
from .types import (
BulkInsertState,
Expand Down Expand Up @@ -733,8 +734,12 @@ def _execute_search(
response = self._stub.Search(request, timeout=timeout)
check_status(response.status)
round_decimal = kwargs.get("round_decimal", -1)
return SearchResult(response.results, round_decimal, status=response.status)

return SearchResult(
response.results,
round_decimal,
status=response.status,
session_ts=response.session_ts,
)
except Exception as e:
if kwargs.get("_async", False):
return SearchFuture(None, None, e)
Expand Down Expand Up @@ -1554,7 +1559,10 @@ def query(
response.fields_data, index, dynamic_fields
)
results.append(entity_row_data)
return ExtraList(results, extra=get_cost_extra(response.status))

extra_dict = get_cost_extra(response.status)
extra_dict[ITERATOR_SESSION_TS_FIELD] = response.session_ts
return ExtraList(results, extra=extra_dict)

@retry_on_rpc_failure()
def load_balance(
Expand Down
15 changes: 9 additions & 6 deletions pymilvus/client/ts_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pymilvus.grpc_gen import common_pb2

from .constants import BOUNDED_TS, EVENTUALLY_TS
from .constants import BOUNDED_TS, EVENTUALLY_TS, GUARANTEE_TIMESTAMP, ITERATOR_FIELD
from .singleton_utils import Singleton
from .types import get_consistency_level
from .utils import hybridts_to_unixtime
Expand Down Expand Up @@ -75,26 +75,29 @@ def get_bounded_ts():


def construct_guarantee_ts(collection_name: str, kwargs: Dict) -> bool:
    """Populate ``kwargs[GUARANTEE_TIMESTAMP]`` from the requested consistency level.

    ``kwargs`` is mutated in place.

    Args:
        collection_name: Collection whose cached timestamp (looked up through
            ``get_collection_ts``) is used for the default and Session cases.
        kwargs: Request options. ``ITERATOR_FIELD`` and ``"consistency_level"``
            are read; ``"consistency_level"`` and ``GUARANTEE_TIMESTAMP`` may
            be written.

    Returns:
        ``True`` when the request carries an iterator marker or specifies no
        consistency level; ``False`` when an explicit level was supplied.
    """
    # Iterator requests are returned early: no guarantee timestamp is written here.
    if kwargs.get(ITERATOR_FIELD) is not None:
        return True

    consistency_level = kwargs.get("consistency_level")
    if consistency_level is None:
        # The collection's default consistency may be Customized or Session,
        # so fall back to the cached mutation ts, or the eventually ts.
        kwargs[GUARANTEE_TIMESTAMP] = get_collection_ts(collection_name) or get_eventually_ts()
        return True

    # Normalize the caller-supplied level and store it back for later use.
    consistency_level = get_consistency_level(consistency_level)
    kwargs["consistency_level"] = consistency_level

    if consistency_level == ConsistencyLevel.Strong:
        # Milvus will assign a newest ts.
        kwargs[GUARANTEE_TIMESTAMP] = 0
    elif consistency_level == ConsistencyLevel.Session:
        # Using the last write ts of the collection.
        # TODO: get a timestamp from server?
        kwargs[GUARANTEE_TIMESTAMP] = get_collection_ts(collection_name) or get_eventually_ts()
    elif consistency_level == ConsistencyLevel.Bounded:
        # Milvus will assign ts according to the server timestamp and a configured time interval
        kwargs[GUARANTEE_TIMESTAMP] = get_bounded_ts()
    else:
        # Users customize the consistency level, no modification on `guarantee_timestamp`.
        kwargs.setdefault(GUARANTEE_TIMESTAMP, get_eventually_ts())
    return False
2 changes: 1 addition & 1 deletion pymilvus/grpc_gen/milvus-proto
528 changes: 264 additions & 264 deletions pymilvus/grpc_gen/milvus_pb2.py

Large diffs are not rendered by default.

12 changes: 8 additions & 4 deletions pymilvus/grpc_gen/milvus_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -791,14 +791,16 @@ class Hits(_message.Message):
def __init__(self, IDs: _Optional[_Iterable[int]] = ..., row_data: _Optional[_Iterable[bytes]] = ..., scores: _Optional[_Iterable[float]] = ...) -> None: ...

class SearchResults(_message.Message):
    # Stub for the search response message: status, results, collection_name
    # and the session_ts integer field.
    # NOTE(review): file path suggests this stub is generated from the proto
    # definitions — confirm before hand-editing.
    __slots__ = ("status", "results", "collection_name", "session_ts")
    STATUS_FIELD_NUMBER: _ClassVar[int]
    RESULTS_FIELD_NUMBER: _ClassVar[int]
    COLLECTION_NAME_FIELD_NUMBER: _ClassVar[int]
    SESSION_TS_FIELD_NUMBER: _ClassVar[int]
    status: _common_pb2.Status
    results: _schema_pb2.SearchResultData
    collection_name: str
    session_ts: int
    def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., results: _Optional[_Union[_schema_pb2.SearchResultData, _Mapping]] = ..., collection_name: _Optional[str] = ..., session_ts: _Optional[int] = ...) -> None: ...

class HybridSearchRequest(_message.Message):
__slots__ = ("base", "db_name", "collection_name", "partition_names", "requests", "rank_params", "travel_timestamp", "guarantee_timestamp", "not_return_all_meta", "output_fields", "consistency_level", "use_default_consistency")
Expand Down Expand Up @@ -920,16 +922,18 @@ class QueryRequest(_message.Message):
def __init__(self, base: _Optional[_Union[_common_pb2.MsgBase, _Mapping]] = ..., db_name: _Optional[str] = ..., collection_name: _Optional[str] = ..., expr: _Optional[str] = ..., output_fields: _Optional[_Iterable[str]] = ..., partition_names: _Optional[_Iterable[str]] = ..., travel_timestamp: _Optional[int] = ..., guarantee_timestamp: _Optional[int] = ..., query_params: _Optional[_Iterable[_Union[_common_pb2.KeyValuePair, _Mapping]]] = ..., not_return_all_meta: bool = ..., consistency_level: _Optional[_Union[_common_pb2.ConsistencyLevel, str]] = ..., use_default_consistency: bool = ...) -> None: ...

class QueryResults(_message.Message):
    # Stub for the query response message: status, fields_data,
    # collection_name, output_fields and the session_ts integer field.
    # NOTE(review): file path suggests this stub is generated from the proto
    # definitions — confirm before hand-editing.
    __slots__ = ("status", "fields_data", "collection_name", "output_fields", "session_ts")
    STATUS_FIELD_NUMBER: _ClassVar[int]
    FIELDS_DATA_FIELD_NUMBER: _ClassVar[int]
    COLLECTION_NAME_FIELD_NUMBER: _ClassVar[int]
    OUTPUT_FIELDS_FIELD_NUMBER: _ClassVar[int]
    SESSION_TS_FIELD_NUMBER: _ClassVar[int]
    status: _common_pb2.Status
    fields_data: _containers.RepeatedCompositeFieldContainer[_schema_pb2.FieldData]
    collection_name: str
    output_fields: _containers.RepeatedScalarFieldContainer[str]
    session_ts: int
    def __init__(self, status: _Optional[_Union[_common_pb2.Status, _Mapping]] = ..., fields_data: _Optional[_Iterable[_Union[_schema_pb2.FieldData, _Mapping]]] = ..., collection_name: _Optional[str] = ..., output_fields: _Optional[_Iterable[str]] = ..., session_ts: _Optional[int] = ...) -> None: ...

class VectorIDs(_message.Message):
__slots__ = ("collection_name", "field_name", "id_array", "partition_names")
Expand Down
Loading

0 comments on commit f068b1a

Please sign in to comment.