Skip to content

Commit

Permalink
Merge branch 'main' into switch-from-black-to-ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
jamshale authored Jul 4, 2024
2 parents f1ac923 + d8490b9 commit 14a6301
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 15 deletions.
46 changes: 31 additions & 15 deletions aries_cloudagent/messaging/models/base_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,10 @@ async def query(
storage = session.inject(BaseStorage)

tag_query = cls.prefix_tag_filter(tag_filter)
if limit is not None or offset is not None:
post_filter = post_filter_positive or post_filter_negative
paginated = limit is not None or offset is not None
if not post_filter and paginated:
# Only fetch paginated records if post-filter is not being applied
rows = await storage.find_paginated_records(
type_filter=cls.RECORD_TYPE,
tag_query=tag_query,
Expand All @@ -328,23 +331,36 @@ async def query(
)

result = []
num_results_post_filter = 0 # to apply pagination post-filter
num_records_to_match = (
(limit or DEFAULT_PAGE_SIZE) + (offset or 0) if paginated else sys.maxsize
) # if pagination is not requested, set to sys.maxsize to process all records
for record in rows:
vals = json.loads(record.value)
if match_post_filter(
vals,
post_filter_positive,
positive=True,
alt=alt,
) and match_post_filter(
vals,
post_filter_negative,
positive=False,
alt=alt,
):
try:
try:
if not post_filter: # pagination would already be applied if requested
result.append(cls.from_storage(record.id, vals))
except BaseModelError as err:
raise BaseModelError(f"{err}, for record id {record.id}")
elif (
(not paginated or num_results_post_filter < num_records_to_match)
and match_post_filter(
vals,
post_filter_positive,
positive=True,
alt=alt,
)
and match_post_filter(
vals,
post_filter_negative,
positive=False,
alt=alt,
)
):
if num_results_post_filter >= (offset or 0):
# append post-filtered records after requested offset
result.append(cls.from_storage(record.id, vals))
num_results_post_filter += 1
except BaseModelError as err:
raise BaseModelError(f"{err}, for record id {record.id}")
return result

async def save(
Expand Down
34 changes: 34 additions & 0 deletions aries_cloudagent/messaging/models/tests/test_base_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,3 +413,37 @@ async def test_query_with_limit_and_offset(self):
assert result[0]._id == record_id
assert result[0].value == record_value
assert result[0].a == "one"

async def test_query_with_limit_and_offset_and_post_filter(self):
session = InMemoryProfile.test_session()
mock_storage = mock.MagicMock(BaseStorage, autospec=True)
session.context.injector.bind_instance(BaseStorage, mock_storage)
record_id = "record_id"
a_record = ARecordImpl(ident=record_id, a="one", b="two", code="red")
record_value = a_record.record_value
record_value.update({"created_at": time_now(), "updated_at": time_now()})
tag_filter = {"code": "red"}
stored = StorageRecord(
ARecordImpl.RECORD_TYPE,
json.dumps(record_value),
{"code": "red"},
record_id,
)
mock_storage.find_all_records.return_value = [stored] * 15 # return 15 records

# Query with limit and offset
result = await ARecordImpl.query(
session,
tag_filter,
limit=10,
offset=5,
post_filter_positive={"a": "one"},
)
mock_storage.find_all_records.assert_awaited_once_with(
type_filter=ARecordImpl.RECORD_TYPE, tag_query=tag_filter
)
assert len(result) == 10
assert result and isinstance(result[0], ARecordImpl)
assert result[0]._id == record_id
assert result[0].value == record_value
assert result[0].a == "one"

0 comments on commit 14a6301

Please sign in to comment.