Skip to content

Commit

Permalink
GroupBuilder: fix broken query kwarg
Browse files (browse the repository at this point in the history)
  • Loading branch information
rkingsbury committed Oct 8, 2024
1 parent a823360 commit c0dfe95
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
17 changes: 8 additions & 9 deletions src/maggma/builders/group_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
self.source = source
self.target = target
self.grouping_keys = grouping_keys
self.query = query
self.query = query if query else {}
self.projection = projection
self.kwargs = kwargs
self.timeout = timeout
Expand Down Expand Up @@ -119,8 +119,9 @@ def get_items(self):

self.total = len(groups)
for group in groups:
docs = list(self.source.query(criteria=dict(zip(self.grouping_keys, group)), properties=projection))
yield docs
group_criteria = dict(zip(self.grouping_keys, group))
group_criteria.update(self.query)
yield list(self.source.query(criteria=group_criteria, properties=projection))

def process_item(self, item: list[dict]) -> dict[tuple, dict]: # type: ignore
keys = [d[self.source.key] for d in item]
Expand Down Expand Up @@ -184,9 +185,7 @@ def get_ids_to_process(self) -> Iterable:
"""
Gets the IDs that need to be processed.
"""
query = self.query or {}

distinct_from_target = list(self.target.distinct(self._target_keys_field, criteria=query))
distinct_from_target = list(self.target.distinct(self._target_keys_field, criteria=self.query))
processed_ids = []
# Not always guaranteed that MongoDB will unpack the list so we
# have to make sure we do that
Expand All @@ -196,19 +195,19 @@ def get_ids_to_process(self) -> Iterable:
else:
processed_ids.append(d)

all_ids = set(self.source.distinct(self.source.key, criteria=query))
all_ids = set(self.source.distinct(self.source.key, criteria=self.query))
self.logger.debug(f"Found {len(all_ids)} total docs in source")

if self.retry_failed:
failed_keys = self.target.distinct(self._target_keys_field, criteria={"state": "failed", **query})
failed_keys = self.target.distinct(self._target_keys_field, criteria={"state": "failed", **self.query})
unprocessed_ids = all_ids - (set(processed_ids) - set(failed_keys))
self.logger.debug(f"Found {len(failed_keys)} failed IDs in target")
else:
unprocessed_ids = all_ids - set(processed_ids)

self.logger.info(f"Found {len(unprocessed_ids)} IDs to process")

new_ids = set(self.source.newer_in(self.target, criteria=query, exhaustive=False))
new_ids = set(self.source.newer_in(self.target, criteria=self.query, exhaustive=False))

self.logger.info(f"Found {len(new_ids)} updated IDs to process")
return list(new_ids | unprocessed_ids)
Expand Down
13 changes: 8 additions & 5 deletions tests/builders/test_group_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Tests for group builder
"""

from datetime import datetime
from datetime import datetime, timezone
from random import randint

import pytest
Expand All @@ -13,7 +13,7 @@

@pytest.fixture(scope="module")
def now():
return datetime.utcnow()
return datetime.now(timezone.utc)


@pytest.fixture()
Expand Down Expand Up @@ -62,9 +62,12 @@ def unary_function(self, items: list[dict]) -> dict:


def test_grouping(source, target, docs):
builder = DummyGrouper(source, target, grouping_keys=["a"])
builder = DummyGrouper(source, target,
query={"k": {"$ne":3}},
grouping_keys=["a"]
)

assert len(docs) == len(builder.get_ids_to_process())
assert len(docs) - 1 == len(builder.get_ids_to_process()), f"{len(docs) -1} != {len(builder.get_ids_to_process())}"
assert len(builder.get_groups_from_keys([d["k"] for d in docs])) == 3

to_process = list(builder.get_items())
Expand All @@ -75,4 +78,4 @@ def test_grouping(source, target, docs):

builder.update_targets(processed)

assert len(builder.get_ids_to_process()) == 0
assert len(builder.get_ids_to_process()) == 0, f"{len(builder.get_ids_to_process())} != 0"

0 comments on commit c0dfe95

Please sign in to comment.