From c0dfe95bd2e5957a31ca2b8e20e50aaf6f1f1f61 Mon Sep 17 00:00:00 2001 From: Ryan Kingsbury Date: Tue, 8 Oct 2024 15:27:00 -0400 Subject: [PATCH] GroupBuilder: fix broken query kwarg --- src/maggma/builders/group_builder.py | 17 ++++++++--------- tests/builders/test_group_builder.py | 13 ++++++++----- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/maggma/builders/group_builder.py b/src/maggma/builders/group_builder.py index 6d58f4006..64993a1f3 100644 --- a/src/maggma/builders/group_builder.py +++ b/src/maggma/builders/group_builder.py @@ -57,7 +57,7 @@ def __init__( self.source = source self.target = target self.grouping_keys = grouping_keys - self.query = query + self.query = query if query else {} self.projection = projection self.kwargs = kwargs self.timeout = timeout @@ -119,8 +119,9 @@ def get_items(self): self.total = len(groups) for group in groups: - docs = list(self.source.query(criteria=dict(zip(self.grouping_keys, group)), properties=projection)) - yield docs + group_criteria = dict(zip(self.grouping_keys, group)) + group_criteria.update(self.query) + yield list(self.source.query(criteria=group_criteria, properties=projection)) def process_item(self, item: list[dict]) -> dict[tuple, dict]: # type: ignore keys = [d[self.source.key] for d in item] @@ -184,9 +185,7 @@ def get_ids_to_process(self) -> Iterable: """ Gets the IDs that need to be processed. """ - query = self.query or {} - - distinct_from_target = list(self.target.distinct(self._target_keys_field, criteria=query)) + distinct_from_target = list(self.target.distinct(self._target_keys_field, criteria=self.query)) processed_ids = [] # Not always guaranteed that MongoDB will unpack the list so we # have to make sure we do that @@ -196,11 +195,11 @@ def get_ids_to_process(self) -> Iterable: else: processed_ids.append(d) - all_ids = set(self.source.distinct(self.source.key, criteria=query)) + all_ids = set(self.source.distinct(self.source.key, criteria=self.query)) self.logger.debug(f"Found {len(all_ids)} total docs in source") if self.retry_failed: - failed_keys = self.target.distinct(self._target_keys_field, criteria={"state": "failed", **query}) + failed_keys = self.target.distinct(self._target_keys_field, criteria={"state": "failed", **self.query}) unprocessed_ids = all_ids - (set(processed_ids) - set(failed_keys)) self.logger.debug(f"Found {len(failed_keys)} failed IDs in target") else: @@ -208,7 +207,7 @@ def get_ids_to_process(self) -> Iterable: self.logger.info(f"Found {len(unprocessed_ids)} IDs to process") - new_ids = set(self.source.newer_in(self.target, criteria=query, exhaustive=False)) + new_ids = set(self.source.newer_in(self.target, criteria=self.query, exhaustive=False)) self.logger.info(f"Found {len(new_ids)} updated IDs to process") return list(new_ids | unprocessed_ids) diff --git a/tests/builders/test_group_builder.py b/tests/builders/test_group_builder.py index b3e18295f..33002ef1c 100644 --- a/tests/builders/test_group_builder.py +++ b/tests/builders/test_group_builder.py @@ -2,7 +2,7 @@ Tests for group builder """ -from datetime import datetime +from datetime import datetime, timezone from random import randint import pytest @@ -13,7 +13,7 @@ @pytest.fixture(scope="module") def now(): - return datetime.utcnow() + return datetime.now(timezone.utc) @pytest.fixture() @@ -62,9 +62,12 @@ def unary_function(self, items: list[dict]) -> dict: def test_grouping(source, target, docs): - builder = DummyGrouper(source, target, grouping_keys=["a"]) + builder = DummyGrouper(source, target, + query={"k": {"$ne":3}}, + grouping_keys=["a"] + ) - assert len(docs) == len(builder.get_ids_to_process()) + assert len(docs) - 1 == len(builder.get_ids_to_process()), f"{len(docs) -1} != {len(builder.get_ids_to_process())}" assert len(builder.get_groups_from_keys([d["k"] for d in docs])) == 3 to_process = list(builder.get_items()) @@ -75,4 +78,4 @@ def test_grouping(source, target, docs): builder.update_targets(processed) - assert len(builder.get_ids_to_process()) == 0 + assert len(builder.get_ids_to_process()) == 0, f"{len(builder.get_ids_to_process())} != 0"