From 2b2a9563bfb4389d3ded50fff126b30f7f1f6497 Mon Sep 17 00:00:00 2001 From: Hardik Ojha Date: Wed, 1 Jun 2022 14:43:35 +0530 Subject: [PATCH 1/7] Built testing heirarchy --- .pre-commit-config.yaml | 1 + testing/__init__.py | 3 ++ testing/cli/__init__.py | 0 testing/cli/test_project_cli.py | 0 testing/cli/test_shell.py | 0 testing/common/__init__.py | 0 testing/common/test_aggregates.py | 0 testing/common/test_directives.py | 0 testing/common/test_environments.py | 0 testing/conftest.py | 27 ++++++++++++++++++ testing/flowproject/__init__.py | 0 testing/flowproject/test_aggregation.py | 0 testing/flowproject/test_groups.py | 0 testing/flowproject/test_hooks.py | 0 testing/flowproject/test_project.py | 0 testing/flowproject/test_status.py | 38 +++++++++++++++++++++++++ testing/flowproject/test_templates.py | 0 testing/utils/__init__.py | 0 testing/utils/test_utils.py | 0 19 files changed, 69 insertions(+) create mode 100644 testing/__init__.py create mode 100644 testing/cli/__init__.py create mode 100644 testing/cli/test_project_cli.py create mode 100644 testing/cli/test_shell.py create mode 100644 testing/common/__init__.py create mode 100644 testing/common/test_aggregates.py create mode 100644 testing/common/test_directives.py create mode 100644 testing/common/test_environments.py create mode 100644 testing/conftest.py create mode 100644 testing/flowproject/__init__.py create mode 100644 testing/flowproject/test_aggregation.py create mode 100644 testing/flowproject/test_groups.py create mode 100644 testing/flowproject/test_hooks.py create mode 100644 testing/flowproject/test_project.py create mode 100644 testing/flowproject/test_status.py create mode 100644 testing/flowproject/test_templates.py create mode 100644 testing/utils/__init__.py create mode 100644 testing/utils/test_utils.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 654f08e38..c7c38a237 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,6 +37,7 @@ repos: exclude: | (?x)^( ^doc/| + ^testing/| ^tests/| ^flow/util/mistune/ ) diff --git a/testing/__init__.py b/testing/__init__.py new file mode 100644 index 000000000..2b7a225e4 --- /dev/null +++ b/testing/__init__.py @@ -0,0 +1,3 @@ +from .conftest import FlowProjectSetup + +__all__ = ["FlowProjectSetup"] diff --git a/testing/cli/__init__.py b/testing/cli/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/testing/cli/test_project_cli.py b/testing/cli/test_project_cli.py new file mode 100644 index 000000000..e69de29bb diff --git a/testing/cli/test_shell.py b/testing/cli/test_shell.py new file mode 100644 index 000000000..e69de29bb diff --git a/testing/common/__init__.py b/testing/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/testing/common/test_aggregates.py b/testing/common/test_aggregates.py new file mode 100644 index 000000000..e69de29bb diff --git a/testing/common/test_directives.py b/testing/common/test_directives.py new file mode 100644 index 000000000..e69de29bb diff --git a/testing/common/test_environments.py b/testing/common/test_environments.py new file mode 100644 index 000000000..e69de29bb diff --git a/testing/conftest.py b/testing/conftest.py new file mode 100644 index 000000000..049c942a4 --- /dev/null +++ b/testing/conftest.py @@ -0,0 +1,27 @@ +import os +from tempfile import TemporaryDirectory + +import pytest +import signac + + +class FlowProjectSetup: + project_class = signac.Project + entrypoint = dict(path="") + project_name = None + + @pytest.fixture(autouse=True) + def _setup(self, request): + self._tmp_dir = TemporaryDirectory(prefix="signac-flow_") + request.addfinalizer(self._tmp_dir.cleanup) + self.project = self.project_class.init_project( + name=self.project_name, root=self._tmp_dir.name + ) + self.cwd = os.getcwd() + + def mock_project(self): + pass + + @pytest.fixture + def project(self): + return self.mock_project() diff --git a/testing/flowproject/__init__.py b/testing/flowproject/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/testing/flowproject/test_aggregation.py b/testing/flowproject/test_aggregation.py new file mode 100644 index 000000000..e69de29bb diff --git a/testing/flowproject/test_groups.py b/testing/flowproject/test_groups.py new file mode 100644 index 000000000..e69de29bb diff --git a/testing/flowproject/test_hooks.py b/testing/flowproject/test_hooks.py new file mode 100644 index 000000000..e69de29bb diff --git a/testing/flowproject/test_project.py b/testing/flowproject/test_project.py new file mode 100644 index 000000000..e69de29bb diff --git a/testing/flowproject/test_status.py b/testing/flowproject/test_status.py new file mode 100644 index 000000000..10effccad --- /dev/null +++ b/testing/flowproject/test_status.py @@ -0,0 +1,38 @@ +from io import StringIO + +from flow import FlowProject +from flow.project import _AggregateStoresCursor + +from .. import FlowProjectSetup + + +class TestStatusPerformance(FlowProjectSetup): + class Project(FlowProject): + pass + + @Project.operation + @Project.post.isfile("DOES_NOT_EXIST") + def foo(job): + pass + + project_class = Project + project_name = "FlowTestProject" + + def mock_project(self): + project = self.project_class.get_project(root=self._tmp_dir.name) + for i in range(1000): + project.open_job(dict(i=i)).init() + return project + + def test_status_performance(self, project): + import timeit + + time = timeit.timeit( + lambda: project._fetch_status( + aggregates=_AggregateStoresCursor(project), + err=StringIO(), + ignore_errors=False, + ), + number=10, + ) + assert time < 10 diff --git a/testing/flowproject/test_templates.py b/testing/flowproject/test_templates.py new file mode 100644 index 000000000..e69de29bb diff --git a/testing/utils/__init__.py b/testing/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/testing/utils/test_utils.py b/testing/utils/test_utils.py new file mode 100644 index 000000000..e69de29bb From 98149fedf8a52a7a91661e6e7305b543fb81466e Mon Sep 17 00:00:00 2001 From: Hardik Ojha Date: Wed, 1 Jun 2022 23:29:52 +0530 Subject: [PATCH 2/7] Back to basics --- testing/cli/__init__.py | 0 testing/common/__init__.py | 0 testing/flowproject/__init__.py | 0 testing/{common => }/test_aggregates.py | 0 testing/{flowproject => }/test_aggregation.py | 0 testing/{common => }/test_directives.py | 0 testing/{common => }/test_environments.py | 0 testing/{flowproject => }/test_groups.py | 0 testing/{flowproject => }/test_hooks.py | 0 testing/{flowproject => }/test_project.py | 0 testing/{cli => }/test_project_cli.py | 0 testing/{cli => }/test_shell.py | 0 testing/{flowproject => }/test_status.py | 0 testing/{flowproject => }/test_templates.py | 0 testing/{utils => }/test_utils.py | 0 testing/utils/__init__.py | 0 16 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 testing/cli/__init__.py delete mode 100644 testing/common/__init__.py delete mode 100644 testing/flowproject/__init__.py rename testing/{common => }/test_aggregates.py (100%) rename testing/{flowproject => }/test_aggregation.py (100%) rename testing/{common => }/test_directives.py (100%) rename testing/{common => }/test_environments.py (100%) rename testing/{flowproject => }/test_groups.py (100%) rename testing/{flowproject => }/test_hooks.py (100%) rename testing/{flowproject => }/test_project.py (100%) rename testing/{cli => }/test_project_cli.py (100%) rename testing/{cli => }/test_shell.py (100%) rename testing/{flowproject => }/test_status.py (100%) rename testing/{flowproject => }/test_templates.py (100%) rename testing/{utils => }/test_utils.py (100%) delete mode 100644 testing/utils/__init__.py diff --git a/testing/cli/__init__.py b/testing/cli/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/testing/common/__init__.py b/testing/common/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/testing/flowproject/__init__.py b/testing/flowproject/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/testing/common/test_aggregates.py b/testing/test_aggregates.py similarity index 100% rename from testing/common/test_aggregates.py rename to testing/test_aggregates.py diff --git a/testing/flowproject/test_aggregation.py b/testing/test_aggregation.py similarity index 100% rename from testing/flowproject/test_aggregation.py rename to testing/test_aggregation.py diff --git a/testing/common/test_directives.py b/testing/test_directives.py similarity index 100% rename from testing/common/test_directives.py rename to testing/test_directives.py diff --git a/testing/common/test_environments.py b/testing/test_environments.py similarity index 100% rename from testing/common/test_environments.py rename to testing/test_environments.py diff --git a/testing/flowproject/test_groups.py b/testing/test_groups.py similarity index 100% rename from testing/flowproject/test_groups.py rename to testing/test_groups.py diff --git a/testing/flowproject/test_hooks.py b/testing/test_hooks.py similarity index 100% rename from testing/flowproject/test_hooks.py rename to testing/test_hooks.py diff --git a/testing/flowproject/test_project.py b/testing/test_project.py similarity index 100% rename from testing/flowproject/test_project.py rename to testing/test_project.py diff --git a/testing/cli/test_project_cli.py b/testing/test_project_cli.py similarity index 100% rename from testing/cli/test_project_cli.py rename to testing/test_project_cli.py diff --git a/testing/cli/test_shell.py b/testing/test_shell.py similarity index 100% rename from testing/cli/test_shell.py rename to testing/test_shell.py diff --git a/testing/flowproject/test_status.py b/testing/test_status.py similarity index 100% rename from testing/flowproject/test_status.py rename to testing/test_status.py diff --git a/testing/flowproject/test_templates.py b/testing/test_templates.py similarity index 100% rename from testing/flowproject/test_templates.py rename to testing/test_templates.py diff --git a/testing/utils/test_utils.py b/testing/test_utils.py similarity index 100% rename from testing/utils/test_utils.py rename to testing/test_utils.py diff --git a/testing/utils/__init__.py b/testing/utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 From ea2070a3937f5d8dc6367dc6cca2d2d2e538b65b Mon Sep 17 00:00:00 2001 From: Hardik Ojha Date: Thu, 2 Jun 2022 17:20:37 +0530 Subject: [PATCH 3/7] Add tests for aggregates.py --- .pre-commit-config.yaml | 1 - testing/__init__.py | 3 - testing/test_aggregates.py | 0 testing/test_aggregation.py | 0 testing/test_directives.py | 0 testing/test_environments.py | 0 testing/test_groups.py | 0 testing/test_hooks.py | 0 testing/test_project.py | 0 testing/test_project_cli.py | 0 testing/test_shell.py | 0 testing/test_status.py | 38 -- testing/test_templates.py | 0 testing/test_utils.py | 0 {testing => tests}/conftest.py | 0 tests/test_aggregates.py | 737 +++++++++++++++++---------------- 16 files changed, 375 insertions(+), 404 deletions(-) delete mode 100644 testing/__init__.py delete mode 100644 testing/test_aggregates.py delete mode 100644 testing/test_aggregation.py delete mode 100644 testing/test_directives.py delete mode 100644 testing/test_environments.py delete mode 100644 testing/test_groups.py delete mode 100644 testing/test_hooks.py delete mode 100644 testing/test_project.py delete mode 100644 testing/test_project_cli.py delete mode 100644 testing/test_shell.py delete mode 100644 testing/test_status.py delete mode 100644 testing/test_templates.py delete mode 100644 testing/test_utils.py rename {testing => tests}/conftest.py (100%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c7c38a237..654f08e38 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,7 +37,6 @@ repos: exclude: | (?x)^( ^doc/| - ^testing/| ^tests/| ^flow/util/mistune/ ) diff --git a/testing/__init__.py b/testing/__init__.py deleted file mode 100644 index 2b7a225e4..000000000 --- a/testing/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .conftest import FlowProjectSetup - -__all__ = ["FlowProjectSetup"] diff --git a/testing/test_aggregates.py b/testing/test_aggregates.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/testing/test_aggregation.py b/testing/test_aggregation.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/testing/test_directives.py b/testing/test_directives.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/testing/test_environments.py b/testing/test_environments.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/testing/test_groups.py b/testing/test_groups.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/testing/test_hooks.py b/testing/test_hooks.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/testing/test_project.py b/testing/test_project.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/testing/test_project_cli.py b/testing/test_project_cli.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/testing/test_shell.py b/testing/test_shell.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/testing/test_status.py b/testing/test_status.py deleted file mode 100644 index 10effccad..000000000 --- a/testing/test_status.py +++ /dev/null @@ -1,38 +0,0 @@ -from io import StringIO - -from flow import FlowProject -from flow.project import _AggregateStoresCursor - -from .. import FlowProjectSetup - - -class TestStatusPerformance(FlowProjectSetup): - class Project(FlowProject): - pass - - @Project.operation - @Project.post.isfile("DOES_NOT_EXIST") - def foo(job): - pass - - project_class = Project - project_name = "FlowTestProject" - - def mock_project(self): - project = self.project_class.get_project(root=self._tmp_dir.name) - for i in range(1000): - project.open_job(dict(i=i)).init() - return project - - def test_status_performance(self, project): - import timeit - - time = timeit.timeit( - lambda: project._fetch_status( - aggregates=_AggregateStoresCursor(project), - err=StringIO(), - ignore_errors=False, - ), - number=10, - ) - assert time < 10 diff --git a/testing/test_templates.py b/testing/test_templates.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/testing/test_utils.py b/testing/test_utils.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/testing/conftest.py b/tests/conftest.py similarity index 100% rename from testing/conftest.py rename to tests/conftest.py diff --git a/tests/test_aggregates.py b/tests/test_aggregates.py index f086a8214..a19c64e39 100644 --- a/tests/test_aggregates.py +++ b/tests/test_aggregates.py @@ -1,362 +1,375 @@ -from functools import partial -from tempfile import TemporaryDirectory - -import pytest -import signac - -from flow.aggregates import _DefaultAggregateStore, aggregator, get_aggregate_id -from flow.errors import FlowProjectDefinitionError - - -@pytest.fixture -def list_of_aggregators(): - def helper_default_aggregator_function(jobs): - yield tuple(jobs) - - def helper_non_default_aggregator_function(jobs): - for job in jobs: - yield (job,) - - # The below list contains 14 distinct aggregator objects and some duplicates. - return [ - aggregator(), - aggregator(), - aggregator(helper_default_aggregator_function), - aggregator(helper_non_default_aggregator_function), - aggregator(helper_non_default_aggregator_function), - aggregator.groupsof(1), - aggregator.groupsof(1), - aggregator.groupsof(2), - aggregator.groupsof(3), - aggregator.groupsof(4), - aggregator.groupby("even"), - aggregator.groupby("even"), - aggregator.groupby("half", -1), - aggregator.groupby("half", -1), - aggregator.groupby(["half", "even"], default=[-1, -1]), - ] - - -class AggregateProjectSetup: - project_class = signac.Project - entrypoint = dict(path="") - - @pytest.fixture - def setUp(self, request): - self._tmp_dir = TemporaryDirectory(prefix="flow-aggregate_") - request.addfinalizer(self._tmp_dir.cleanup) - self.project = self.project_class.init_project( - name="AggregateTestProject", root=self._tmp_dir.name - ) - - def mock_project(self): - project = self.project_class.get_project(root=self._tmp_dir.name) - for i in range(10): - even = (i % 2) == 0 - if even: - project.open_job(dict(i=i, half=i / 2, even=even)).init() - else: - project.open_job(dict(i=i, even=even)).init() - return project - - @pytest.fixture - def project(self): - return self.mock_project() - - -# Test the decorator class aggregator -class TestAggregate(AggregateProjectSetup): - def test_default_init(self): - aggregate_instance = aggregator() - # Ensure that all values are converted to tuples - test_values = [(1, 2, 3, 4, 5), (), [1, 2, 3, 4, 5], []] - assert not aggregate_instance._is_default_aggregator - assert aggregate_instance._sort_by is None - assert aggregate_instance._sort_ascending - assert aggregate_instance._select is None - for value in test_values: - assert list(aggregate_instance._aggregator_function(value)) == [ - tuple(value) - ] - - def test_invalid_aggregator_function(self, setUp, project): - aggregator_functions = ["str", 1, {}] - for aggregator_function in aggregator_functions: - with pytest.raises(TypeError): - aggregator(aggregator_function) - - def test_invalid_sort_by(self): - sort_list = [1, {}] - for sort in sort_list: - with pytest.raises(TypeError): - aggregator(sort_by=sort) - - def test_invalid_select(self): - selectors = ["str", 1, []] - for _select in selectors: - with pytest.raises(TypeError): - aggregator(select=_select) - - def test_invalid_call(self): - call_params = ["str", 1, None] - for param in call_params: - with pytest.raises(FlowProjectDefinitionError): - aggregator()(param) - - def test_call_without_decorator(self): - aggregate_instance = aggregator() - with pytest.raises(FlowProjectDefinitionError): - aggregate_instance() - - def test_call_with_decorator(self): - @aggregator() - def test_function(x): - return x - - assert hasattr(test_function, "_flow_aggregate") - - def test_groups_of_invalid_num(self): - invalid_values = [{}, "str", -1, -1.5] - for invalid_value in invalid_values: - with pytest.raises((TypeError, ValueError)): - aggregator.groupsof(invalid_value) - - def test_group_by_invalid_key(self): - with pytest.raises(TypeError): - aggregator.groupby(1) - - def test_groupby_with_valid_type_default_for_Iterable(self): - aggregator.groupby(["half", "even"], default=[-1, -1]) - - def test_groupby_with_invalid_type_default_key_for_Iterable(self): - with pytest.raises(TypeError): - aggregator.groupby(["half", "even"], default=-1) - - def test_groupby_with_invalid_length_default_key_for_Iterable(self): - with pytest.raises(ValueError): - aggregator.groupby(["half", "even"], default=[-1, -1, -1]) - - def test_aggregate_hashing(self, list_of_aggregators): - # Since we need to store groups on a per aggregate basis in the project, - # we need to be sure that the aggregates are hashing and compared correctly. - # This test ensures this feature. - # list_of_aggregators contains 14 distinct store objects (because an - # aggregator object is differentiated on the basis of the `_is_aggregate` attribute). - # When this list is converted to set, then these objects are hashed first - # and then compared. Since sets don't carry duplicate values, we test - # whether the length of the set obtained from the list is equal to 14 or not. - assert len(set(list_of_aggregators)) == 14 - # Ensure that equality implies hash equality. - for agg1 in list_of_aggregators: - for agg2 in list_of_aggregators: - if agg1 == agg2: - assert hash(agg1) == hash(agg2) - - -# Test the _AggregateStore and _DefaultAggregateStore classes -class TestAggregateStore(AggregateProjectSetup): - def test_custom_aggregator_function(self, setUp, project): - # Testing aggregator function returning aggregates of 1 - def helper_aggregator_function(jobs): - for job in jobs: - yield (job,) - - aggregate_instance = aggregator(helper_aggregator_function) - aggregate_store = aggregate_instance._create_AggregateStore(project) - aggregate_job_manual = helper_aggregator_function(project) - assert tuple(aggregate_job_manual) == tuple(aggregate_store.values()) - - # Testing aggregator function returning aggregates of all the jobs - aggregate_instance = aggregator(lambda jobs: [jobs]) - aggregate_store = aggregate_instance._create_AggregateStore(project) - assert (tuple(project),) == tuple(aggregate_store.values()) - - def test_sort_by(self, setUp, project): - helper_sort = partial(sorted, key=lambda job: job.sp.i) - aggregate_instance = aggregator(sort_by="i") - aggregate_store = aggregate_instance._create_AggregateStore(project) - assert (tuple(helper_sort(project)),) == tuple(aggregate_store.values()) - - def test_sort_by_callable(self, setUp, project): - def keyfunction(job): - return job.sp.i - - helper_sort = partial(sorted, key=keyfunction) - aggregate_instance = aggregator(sort_by=keyfunction) - aggregate_store = aggregate_instance._create_AggregateStore(project) - assert (tuple(helper_sort(project)),) == tuple(aggregate_store.values()) - - def test_sort_descending(self, setUp, project): - helper_sort = partial(sorted, key=lambda job: job.sp.i, reverse=True) - aggregate_instance = aggregator(sort_by="i", sort_ascending=False) - aggregate_store = aggregate_instance._create_AggregateStore(project) - assert (tuple(helper_sort(project)),) == tuple(aggregate_store.values()) - - def test_groups_of_valid_num(self, setUp, project): - valid_values = [1, 2, 3, 6, 10] - # Expected length of aggregates which are made using the above values. - expected_length_of_aggregators = [10, 5, 4, 2, 1] - # Expect length of each aggregate which are made using the above - # values. The zeroth index of the nested list denotes the length of all - # the aggregates expect the last one. The first index denotes the - # length of the last aggregate formed. - expected_length_per_aggregate = [[1, 1], [2, 2], [3, 1], [6, 4], [10, 10]] - for i, valid_value in enumerate(valid_values): - aggregate_instance = aggregator.groupsof(valid_value) - aggregate_store = aggregate_instance._create_AggregateStore(project) - expected_len = expected_length_of_aggregators[i] - assert len(aggregate_store) == expected_len - - # We also check the length of every aggregate in order to ensure - # proper aggregation. - for j, aggregate in enumerate(aggregate_store.values()): - if j == expected_len - 1: # Checking for the last aggregate - assert len(aggregate) == expected_length_per_aggregate[i][1] - else: - assert len(aggregate) == expected_length_per_aggregate[i][0] - - def test_groupby_with_valid_string_key(self, setUp, project): - aggregate_instance = aggregator.groupby("even") - aggregate_store = aggregate_instance._create_AggregateStore(project) - for aggregate in aggregate_store.values(): - even = aggregate[0].sp.even - assert all(even == job.sp.even for job in aggregate) - assert len(aggregate_store) == 2 - - def test_groupby_with_invalid_string_key(self, setUp, project): - aggregate_instance = aggregator.groupby("invalid_key") - with pytest.raises(KeyError): - # We will attempt to generate aggregates but will fail in - # doing so due to the invalid key - aggregate_instance._create_AggregateStore(project) - - def test_groupby_with_default_key_for_string(self, setUp, project): - aggregate_instance = aggregator.groupby("half", default=-1) - aggregate_store = aggregate_instance._create_AggregateStore(project) - for aggregate in aggregate_store.values(): - half = aggregate[0].sp.get("half", -1) - assert all(half == job.sp.get("half", -1) for job in aggregate) - assert len(aggregate_store) == 6 - - def test_groupby_with_Iterable_key(self, setUp, project): - aggregate_instance = aggregator.groupby(["i", "even"]) - aggregate_store = aggregate_instance._create_AggregateStore(project) - # No aggregation takes place hence this means we don't need to check - # whether all the aggregate members are equivalent. - assert len(aggregate_store) == 10 - - def test_groupby_with_invalid_Iterable_key(self, setUp, project): - aggregate_instance = aggregator.groupby(["half", "even"]) - with pytest.raises(KeyError): - # We will attempt to generate aggregates but will fail in - # doing so due to the invalid keys - aggregate_instance._create_AggregateStore(project) - - def test_groupby_with_valid_default_key_for_Iterable(self, setUp, project): - aggregate_instance = aggregator.groupby(["half", "even"], default=[-1, -1]) - aggregate_store = aggregate_instance._create_AggregateStore(project) - for aggregate in aggregate_store.values(): - half = aggregate[0].sp.get("half", -1) - even = aggregate[0].sp.get("even", -1) - assert all( - half == job.sp.get("half", -1) and even == job.sp.get("even", -1) - for job in aggregate - ) - assert len(aggregate_store) == 6 - - def test_groupby_with_callable_key(self, setUp, project): - def keyfunction(job): - return job.sp["even"] - - aggregate_instance = aggregator.groupby(keyfunction) - aggregate_store = aggregate_instance._create_AggregateStore(project) - for aggregate in aggregate_store.values(): - even = aggregate[0].sp.even - assert all(even == job.sp.even for job in aggregate) - assert len(aggregate_store) == 2 - - def test_groupby_with_invalid_callable_key(self, setUp, project): - def keyfunction(job): - return job.sp["half"] - - aggregate_instance = aggregator.groupby(keyfunction) - with pytest.raises(KeyError): - # We will attempt to generate aggregates but will fail in - # doing so due to the invalid key - aggregate_instance._create_AggregateStore(project) - - def test_valid_select(self, setUp, project): - def _select(job): - return job.sp.i > 5 - - aggregate_instance = aggregator.groupsof(1, select=_select) - aggregate_store = aggregate_instance._create_AggregateStore(project) - selected_jobs = [] - for job in project: - if _select(job): - selected_jobs.append((job,)) - assert list(aggregate_store.values()) == selected_jobs - - def test_store_hashing(self, setUp, project, list_of_aggregators): - # Since we need to store groups on a per aggregate basis in the project, - # we need to be sure that the aggregates are hashing and compared correctly. - # This test ensures this feature. - list_of_stores = [ - aggregator._create_AggregateStore(project) - for aggregator in list_of_aggregators - ] - assert len(list_of_stores) == len(list_of_aggregators) - # The above list contains 14 distinct store objects (because a - # store object is differentiated on the basis of the - # ``_is_default_aggregate`` attribute of the aggregator). When this - # list is converted to a set, then these objects are hashed first and - # then compared. Since sets don't carry duplicate values, we test - # whether the length of the set obtained from the list is equal to 14 - # or not. - assert len(set(list_of_stores)) == 14 - - def test_aggregates_are_tuples(self, setUp, project, list_of_aggregators): - # This test ensures that all aggregator functions return tuples. All - # aggregate stores are expected to return tuples for their values, but - # this also tests that the aggregator functions (groupsof, groupby) are - # generating tuples internally. - for aggregator_instance in list_of_aggregators: - aggregate_store = aggregator_instance._create_AggregateStore(project) - if not isinstance(aggregate_store, _DefaultAggregateStore): - for aggregate in aggregate_store._generate_aggregates(): - assert isinstance(aggregate, tuple) - assert all( - isinstance(job, signac.contrib.job.Job) for job in aggregate - ) - for aggregate in aggregate_store.values(): - assert isinstance(aggregate, tuple) - assert all(isinstance(job, signac.contrib.job.Job) for job in aggregate) - - def test_get_by_id(self, setUp, project, list_of_aggregators): - # Ensure that all aggregates can be fetched by id. - for aggregator_instance in list_of_aggregators: - aggregate_store = aggregator_instance._create_AggregateStore(project) - for aggregate in aggregate_store.values(): - assert aggregate == aggregate_store[get_aggregate_id(aggregate)] - - def test_get_invalid_id(self, setUp, project): - jobs = tuple(project) - aggregator_instance = aggregator()._create_AggregateStore(project) - default_aggregator = aggregator.groupsof(1)._create_AggregateStore(project) - # Test for an aggregate of single job for aggregator instance - with pytest.raises(LookupError): - aggregator_instance[get_aggregate_id((jobs[0],))] - # Test for an aggregate of all jobs for default aggregator - with pytest.raises(LookupError): - default_aggregator[get_aggregate_id(jobs)] - - def test_contains(self, setUp, project): - jobs = tuple(project) - aggregator_instance = aggregator()._create_AggregateStore(project) - default_aggregator = aggregator.groupsof(1)._create_AggregateStore(project) - # Test for an aggregate of all jobs - assert get_aggregate_id(jobs) in aggregator_instance - assert get_aggregate_id(jobs) not in default_aggregator - # Test for an aggregate of single job - assert not jobs[0].id in aggregator_instance - assert jobs[0].id in default_aggregator +from functools import partial +from math import ceil + +import pytest +import signac +from conftest import FlowProjectSetup + +from flow.aggregates import _DefaultAggregateStore, aggregator, get_aggregate_id +from flow.errors import FlowProjectDefinitionError + + +class AggregateProjectSetup(FlowProjectSetup): + project_name = "AggregateTestProject" + + def mock_project(self): + project = self.project_class.get_project(root=self._tmp_dir.name) + for i in range(10): + even = (i % 2) == 0 + if even: + project.open_job(dict(i=i, half=i / 2, even=even)).init() + else: + project.open_job(dict(i=i, even=even)).init() + return project + + +class AggregateFixtures: + @classmethod + def _get_single_job_aggregate(cls, jobs): + for job in jobs: + yield (job,) + + @classmethod + def _get_all_job_aggregate(cls, jobs): + return (tuple(jobs),) + + @pytest.fixture + def get_single_job_aggregate(self): + return AggregateFixtures._get_single_job_aggregate + + @pytest.fixture + def get_all_job_aggregate(self): + return AggregateFixtures._get_all_job_aggregate + + @classmethod + def list_of_aggregators(cls): + # The below list contains 14 distinct aggregator objects and some duplicates. + return [ + aggregator(), + aggregator(), + aggregator(cls._get_all_job_aggregate), + aggregator(cls._get_single_job_aggregate), + aggregator(cls._get_single_job_aggregate), + aggregator.groupsof(1), + aggregator.groupsof(1), + aggregator.groupsof(2), + aggregator.groupsof(3), + aggregator.groupsof(4), + aggregator.groupby("even"), + aggregator.groupby("even"), + aggregator.groupby("half", -1), + aggregator.groupby("half", -1), + aggregator.groupby(["half", "even"], default=[-1, -1]), + ] + + def create_aggregate_store(self, aggregator_instance, project): + return aggregator_instance._create_AggregateStore(project) + + def get_aggregates_from_store(self, aggregate_store): + return tuple(aggregate_store.values()) + + +# Test the _AggregateStore and _DefaultAggregateStore classes +class TestAggregateStore(AggregateProjectSetup, AggregateFixtures): + def test_custom_aggregator_function( + self, project, get_single_job_aggregate, get_all_job_aggregate + ): + # Testing aggregator function returning aggregates of 1 + aggregate_store = self.create_aggregate_store( + aggregator(get_single_job_aggregate), project + ) + aggregate_job_manual = tuple(get_single_job_aggregate(project)) + assert self.get_aggregates_from_store(aggregate_store) == aggregate_job_manual + + # Testing aggregator function returning aggregates of all the jobs + aggregate_store = self.create_aggregate_store( + aggregator(get_all_job_aggregate), project + ) + assert self.get_aggregates_from_store(aggregate_store) == get_all_job_aggregate( + project + ) + + def test_sort_by(self, project): + helper_sort = partial(sorted, key=lambda job: job.sp.i) + aggregate_store = self.create_aggregate_store(aggregator(sort_by="i"), project) + assert self.get_aggregates_from_store(aggregate_store) == ( + tuple(helper_sort(project)), + ) + + def test_sort_by_callable(self, project): + def keyfunction(job): + return job.sp.i + + helper_sort = partial(sorted, key=keyfunction) + aggregate_store = self.create_aggregate_store( + aggregator(sort_by=keyfunction), project + ) + assert self.get_aggregates_from_store(aggregate_store) == ( + tuple(helper_sort(project)), + ) + + def test_sort_descending(self, project): + helper_sort = partial(sorted, key=lambda job: job.sp.i, reverse=True) + aggregate_store = self.create_aggregate_store( + aggregator(sort_by="i", sort_ascending=False), project + ) + assert self.get_aggregates_from_store(aggregate_store) == ( + tuple(helper_sort(project)), + ) + + @pytest.mark.parametrize("aggregate_length", [1, 2, 3, 6, 10]) + def test_groups_of_valid_num(self, project, aggregate_length): + aggregate_store = self.create_aggregate_store( + aggregator.groupsof(aggregate_length), project + ) + expected_len = ceil(10 / aggregate_length) + assert len(aggregate_store) == expected_len + + # We also check the length of every aggregate in order to ensure + # proper aggregation. + last_agg_len = ( + 10 % aggregate_length if (10 % aggregate_length != 0) else aggregate_length + ) + for j, aggregate in enumerate(aggregate_store.values()): + if j == expected_len - 1: # Checking for the last aggregate + assert len(aggregate) == last_agg_len + else: + assert len(aggregate) == aggregate_length + + def test_groupby_with_valid_string_key(self, project): + aggregate_store = self.create_aggregate_store( + aggregator.groupby("even"), project + ) + assert len(aggregate_store) == 2 + for aggregate in aggregate_store.values(): + even = aggregate[0].sp.even + assert all(even == job.sp.even for job in aggregate) + + def test_groupby_with_invalid_string_key(self, project): + with pytest.raises(KeyError): + # We will attempt to generate aggregates but will fail in + # doing so due to the invalid key + self.create_aggregate_store(aggregator.groupby("invalid_key"), project) + + def test_groupby_with_default_key_for_string(self, project): + aggregate_store = self.create_aggregate_store( + aggregator.groupby("half", default=-1), project + ) + assert len(aggregate_store) == 6 + for aggregate in aggregate_store.values(): + half = aggregate[0].sp.get("half", -1) + assert all(half == job.sp.get("half", -1) for job in aggregate) + + def test_groupby_with_Iterable_key(self, project): + aggregate_store = self.create_aggregate_store( + aggregator.groupby(["i", "even"]), project + ) + # No aggregation takes place hence this means we don't need to check + # whether all the aggregate members are equivalent. + assert len(aggregate_store) == 10 + + def test_groupby_with_invalid_Iterable_key(self, project): + with pytest.raises(KeyError): + # We will attempt to generate aggregates but will fail in + # doing so due to the invalid keys + self.create_aggregate_store(aggregator.groupby(["half", "even"]), project) + + def test_groupby_with_valid_default_key_for_Iterable(self, project): + aggregate_store = self.create_aggregate_store( + aggregator.groupby(["half", "even"], default=[-1, -1]), project + ) + assert len(aggregate_store) == 6 + for aggregate in aggregate_store.values(): + half = aggregate[0].sp.get("half", -1) + even = aggregate[0].sp.get("even", -1) + assert all( + half == job.sp.get("half", -1) and even == job.sp.get("even", -1) + for job in aggregate + ) + + def test_groupby_with_callable_key(self, project): + def keyfunction(job): + return job.sp["even"] + + aggregate_store = self.create_aggregate_store( + aggregator.groupby(keyfunction), project + ) + assert len(aggregate_store) == 2 + for aggregate in aggregate_store.values(): + even = aggregate[0].sp.even + assert all(even == job.sp.even for job in aggregate) + + def test_groupby_with_invalid_callable_key(self, project): + def keyfunction(job): + return job.sp["half"] + + with pytest.raises(KeyError): + # We will attempt to generate aggregates but will fail in + # doing so due to the invalid key + self.create_aggregate_store(aggregator.groupby(keyfunction), project) + + def test_valid_select(self, project): + def _select(job): + return job.sp.i > 5 + + aggregate_store = self.create_aggregate_store( + aggregator.groupsof(1, select=_select), project + ) + selected_jobs = [] + for job in project: + if _select(job): + selected_jobs.append((job,)) + assert self.get_aggregates_from_store(aggregate_store) == tuple(selected_jobs) + + def test_store_hashing(self, project): + # Since we need to store groups on a per aggregate basis in the project, + # we need to be sure that the aggregates are hashing and compared correctly. + # This test ensures this feature. + total_aggregators = AggregateFixtures.list_of_aggregators() + list_of_stores = [ + self.create_aggregate_store(aggregator, project) + for aggregator in total_aggregators + ] + assert len(list_of_stores) == len(total_aggregators) + # The above list contains 14 distinct store objects (because a + # store object is differentiated on the basis of the + # ``_is_default_aggregate`` attribute of the aggregator). When this + # list is converted to a set, then these objects are hashed first and + # then compared. Since sets don't carry duplicate values, we test + # whether the length of the set obtained from the list is equal to 14 + # or not. + assert len(set(list_of_stores)) == 14 + + @pytest.mark.parametrize( + "aggregator_instance", AggregateFixtures.list_of_aggregators() + ) + def test_aggregates_are_tuples(self, project, aggregator_instance): + # This test ensures that all aggregator functions return tuples. All + # aggregate stores are expected to return tuples for their values, but + # this also tests that the aggregator functions (groupsof, groupby) are + # generating tuples internally. + aggregate_store = self.create_aggregate_store(aggregator_instance, project) + if not isinstance(aggregate_store, _DefaultAggregateStore): + for aggregate in aggregate_store._generate_aggregates(): + assert isinstance(aggregate, tuple) + assert all(isinstance(job, signac.contrib.job.Job) for job in aggregate) + for aggregate in aggregate_store.values(): + assert isinstance(aggregate, tuple) + assert all(isinstance(job, signac.contrib.job.Job) for job in aggregate) + + @pytest.mark.parametrize( + "aggregator_instance", AggregateFixtures.list_of_aggregators() + ) + def test_get_by_id(self, project, aggregator_instance): + # Ensure that all aggregates can be fetched by id. + aggregate_store = self.create_aggregate_store(aggregator_instance, project) + for aggregate in aggregate_store.values(): + assert aggregate == aggregate_store[get_aggregate_id(aggregate)] + + def test_get_invalid_id(self, project): + jobs = tuple(project) + full_aggregate_store = self.create_aggregate_store(aggregator(), project) + default_aggregate_store = self.create_aggregate_store( + aggregator.groupsof(1), project + ) + # Test for an aggregate of single job for aggregator instance + with pytest.raises(LookupError): + full_aggregate_store[get_aggregate_id((jobs[0],))] + # Test for an aggregate of all jobs for default aggregator + with pytest.raises(LookupError): + default_aggregate_store[get_aggregate_id(jobs)] + + def test_contains(self, project): + jobs = tuple(project) + full_aggregate_store = self.create_aggregate_store(aggregator(), project) + default_aggregate_store = aggregator.groupsof(1)._create_AggregateStore(project) + # Test for an aggregate of all jobs + assert get_aggregate_id(jobs) in full_aggregate_store + assert get_aggregate_id(jobs) not in default_aggregate_store + # Test for an aggregate of single job + assert not jobs[0].id in full_aggregate_store + assert jobs[0].id in default_aggregate_store + + +# Test the decorator class aggregator +class TestAggregate(AggregateProjectSetup, AggregateFixtures): + @pytest.mark.parametrize("agg_value", [(1, 2, 3, 4, 5), (), [1, 2, 3, 4, 5], []]) + def test_default_init(self, agg_value): + aggregate_instance = aggregator() + assert not aggregate_instance._is_default_aggregator + assert aggregate_instance._sort_by is None + assert aggregate_instance._sort_ascending + assert aggregate_instance._select is None + assert list(aggregate_instance._aggregator_function(agg_value)) == [ + tuple(agg_value) + ] + + @pytest.mark.parametrize("aggregator_function", ["str", 1, {}]) + def test_invalid_aggregator_function(self, aggregator_function): + with pytest.raises(TypeError): + aggregator(aggregator_function) + + @pytest.mark.parametrize("sort_by", [1, {}]) + def test_invalid_sort_by(self, sort_by): + with pytest.raises(TypeError): + aggregator(sort_by=sort_by) + + @pytest.mark.parametrize("select", ["str", 1, []]) + def test_invalid_select(self, select): + with pytest.raises(TypeError): + aggregator(select=select) + + @pytest.mark.parametrize("param", ["str", 1, None]) + def test_invalid_call(self, param): + aggregator_instance = aggregator() + with pytest.raises(FlowProjectDefinitionError): + aggregator_instance(param) + + def test_call_without_decorator(self): + aggregate_instance = aggregator() + with pytest.raises(FlowProjectDefinitionError): + aggregate_instance() + + def test_call_with_decorator(self): + @aggregator() + def test_function(x): + return x + + assert hasattr(test_function, "_flow_aggregate") + + @pytest.mark.parametrize("invalid_value", [{}, "str", -1, -1.5]) + def test_groups_of_invalid_num(self, invalid_value): + with pytest.raises((TypeError, ValueError)): + aggregator.groupsof(invalid_value) + + def test_group_by_invalid_key(self): + with pytest.raises(TypeError): + aggregator.groupby(1) + + def test_groupby_with_valid_type_default_for_Iterable(self): + aggregator.groupby(["half", "even"], default=[-1, -1]) + + def test_groupby_with_invalid_type_default_key_for_Iterable(self): + with pytest.raises(TypeError): + aggregator.groupby(["half", "even"], default=-1) + + def test_groupby_with_invalid_length_default_key_for_Iterable(self): + with pytest.raises(ValueError): + aggregator.groupby(["half", "even"], default=[-1, -1, -1]) + + def test_aggregate_hashing(self): + # Since we need to store groups on a per aggregate basis in the project, + # we need to be sure that the aggregates are hashing and compared correctly. + # This test ensures this feature. + # list_of_aggregators contains 14 distinct store objects (because an + # aggregator object is differentiated on the basis of the `_is_aggregate` attribute). + # When this list is converted to set, then these objects are hashed first + # and then compared. Since sets don't carry duplicate values, we test + # whether the length of the set obtained from the list is equal to 14 or not. + total_aggregators = AggregateFixtures.list_of_aggregators() + assert len(set(total_aggregators)) == 14 + # Ensure that equality implies hash equality. + for agg1 in total_aggregators: + for agg2 in total_aggregators: + if agg1 == agg2: + assert hash(agg1) == hash(agg2) From d011d6aa36bb4be7753b4b4287c5a7abf33f5df4 Mon Sep 17 00:00:00 2001 From: Hardik Ojha <44747868+kidrahahjo@users.noreply.github.com> Date: Thu, 2 Jun 2022 17:23:35 +0530 Subject: [PATCH 4/7] Update requirements/requirements-test.txt --- requirements/requirements-test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 59115283c..4056180ac 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -1,5 +1,5 @@ click==8.1.3 -coverage==6.4 +coverage==6.3.2 pytest-cov==3.0.0 pytest==7.1.2 ruamel.yaml==0.17.21 From 65f3aadce4747d5d9ea211d2f80bb834fa4102f8 Mon Sep 17 00:00:00 2001 From: Hardik Ojha <44747868+kidrahahjo@users.noreply.github.com> Date: Thu, 2 Jun 2022 17:23:41 +0530 Subject: [PATCH 5/7] Update requirements/requirements-dev.txt --- requirements/requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index d91cb7f24..66a69bade 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -1,3 +1,3 @@ click>=7.0 ruamel.yaml>=0.16.12 -pre-commit==2.19.0 +pre-commit==2.18.1 From a0749a423ad5cba5d893141f9246c822a7e73488 Mon Sep 17 00:00:00 2001 From: Hardik Ojha Date: Thu, 11 Aug 2022 01:49:25 +0530 Subject: [PATCH 6/7] Resolve review comments --- tests/test_aggregates.py | 83 +++++++++++++++++----------------------- 1 file changed, 36 insertions(+), 47 deletions(-) diff --git a/tests/test_aggregates.py b/tests/test_aggregates.py index 065957f23..5b5b49a42 100644 --- a/tests/test_aggregates.py +++ b/tests/test_aggregates.py @@ -5,7 +5,7 @@ import signac from conftest import TestProjectBase -from flow.aggregates import _DefaultAggregateStore, aggregator, get_aggregate_id +from flow.aggregates import aggregator, get_aggregate_id from flow.errors import FlowProjectDefinitionError @@ -15,11 +15,11 @@ class AggregateProjectSetup(TestProjectBase): def mock_project(self): project = self.project_class.get_project(root=self._tmp_dir.name) for i in range(10): - even = (i % 2) == 0 - if even: - project.open_job(dict(i=i, half=i / 2, even=even)).init() + is_even = (i % 2) == 0 + if is_even: + project.open_job(dict(i=i, half=i / 2, is_even=is_even)).init() else: - project.open_job(dict(i=i, even=even)).init() + project.open_job(dict(i=i, is_even=is_even)).init() return project @pytest.fixture @@ -59,11 +59,11 @@ def list_of_aggregators(cls): aggregator.groupsof(2), aggregator.groupsof(3), aggregator.groupsof(4), - aggregator.groupby("even"), - aggregator.groupby("even"), + aggregator.groupby("is_even"), + aggregator.groupby("is_even"), aggregator.groupby("half", -1), aggregator.groupby("half", -1), - aggregator.groupby(["half", "even"], default=[-1, -1]), + aggregator.groupby(["half", "is_even"], default=[-1, -1]), ] def create_aggregate_store(self, aggregator_instance, mocked_project): @@ -128,28 +128,26 @@ def test_groups_of_valid_num(self, mocked_project, aggregate_length): aggregate_store = self.create_aggregate_store( aggregator.groupsof(aggregate_length), mocked_project ) - expected_len = ceil(10 / aggregate_length) - assert len(aggregate_store) == expected_len + expected_length = ceil(10 / aggregate_length) + assert len(aggregate_store) == expected_length # We also check the length of every aggregate in order to ensure # proper aggregation. - last_agg_len = ( + last_agg_length = ( 10 % aggregate_length if (10 % aggregate_length != 0) else aggregate_length ) - for j, aggregate in enumerate(aggregate_store.values()): - if j == expected_len - 1: # Checking for the last aggregate - assert len(aggregate) == last_agg_len - else: - assert len(aggregate) == aggregate_length + aggregates = list(aggregate_store.values()) + assert all(len(agg) == aggregate_length for agg in aggregates[:-1]) + assert len(aggregates[-1]) == last_agg_length def test_groupby_with_valid_string_key(self, mocked_project): aggregate_store = self.create_aggregate_store( - aggregator.groupby("even"), mocked_project + aggregator.groupby("is_even"), mocked_project ) assert len(aggregate_store) == 2 for aggregate in aggregate_store.values(): - even = aggregate[0].sp.even - assert all(even == job.sp.even for job in aggregate) + even = aggregate[0].sp.is_even + assert all(even == job.sp.is_even for job in aggregate) def test_groupby_with_invalid_string_key(self, mocked_project): with pytest.raises(KeyError): @@ -170,44 +168,44 @@ def test_groupby_with_default_key_for_string(self, mocked_project): def test_groupby_with_Iterable_key(self, mocked_project): aggregate_store = self.create_aggregate_store( - aggregator.groupby(["i", "even"]), mocked_project + aggregator.groupby(["i", "is_even"]), mocked_project ) # No aggregation takes place hence this means we don't need to check # whether all the aggregate members are equivalent. - assert len(aggregate_store) == 10 + assert len(aggregate_store) == len(mocked_project) def test_groupby_with_invalid_Iterable_key(self, mocked_project): with pytest.raises(KeyError): # We will attempt to generate aggregates but will fail in # doing so due to the invalid keys self.create_aggregate_store( - aggregator.groupby(["half", "even"]), mocked_project + aggregator.groupby(["half", "is_even"]), mocked_project ) def test_groupby_with_valid_default_key_for_Iterable(self, mocked_project): aggregate_store = self.create_aggregate_store( - aggregator.groupby(["half", "even"], default=[-1, -1]), mocked_project + aggregator.groupby(["half", "is_even"], default=[-1, -1]), mocked_project ) assert len(aggregate_store) == 6 for aggregate in aggregate_store.values(): half = aggregate[0].sp.get("half", -1) - even = aggregate[0].sp.get("even", -1) + even = aggregate[0].sp["is_even"] assert all( - half == job.sp.get("half", -1) and even == job.sp.get("even", -1) + half == job.sp.get("half", -1) and even == job.sp["is_even"] for job in aggregate ) def test_groupby_with_callable_key(self, mocked_project): def keyfunction(job): - return job.sp["even"] + return job.sp["is_even"] aggregate_store = self.create_aggregate_store( aggregator.groupby(keyfunction), mocked_project ) assert len(aggregate_store) == 2 for aggregate in aggregate_store.values(): - even = aggregate[0].sp.even - assert all(even == job.sp.even for job in aggregate) + even = aggregate[0].sp.is_even + assert all(even == job.sp.is_even for job in aggregate) def test_groupby_with_invalid_callable_key(self, mocked_project): def keyfunction(job): @@ -219,28 +217,22 @@ def keyfunction(job): self.create_aggregate_store(aggregator.groupby(keyfunction), mocked_project) def test_valid_select(self, mocked_project): - def _select(job): + def select(job): return job.sp.i > 5 aggregate_store = self.create_aggregate_store( - aggregator.groupsof(1, select=_select), mocked_project + aggregator.groupsof(1, select=select), mocked_project ) - selected_jobs = [] - for job in mocked_project: - if _select(job): - selected_jobs.append((job,)) - assert self.get_aggregates_from_store(aggregate_store) == tuple(selected_jobs) + selected_jobs = tuple((job,) for job in mocked_project if select(job)) + assert self.get_aggregates_from_store(aggregate_store) == selected_jobs def test_store_hashing(self, mocked_project): # Since we need to store groups on a per aggregate basis in the mocked_project, # we need to be sure that the aggregates are hashing and compared correctly. - # This test ensures this feature. - total_aggregators = AggregateFixtures.list_of_aggregators() list_of_stores = [ self.create_aggregate_store(aggregator, mocked_project) - for aggregator in total_aggregators + for aggregator in AggregateFixtures.list_of_aggregators() ] - assert len(list_of_stores) == len(total_aggregators) # The above list contains 14 distinct store objects (because a # store object is differentiated on the basis of the # ``_is_default_aggregate`` attribute of the aggregator). When this @@ -261,10 +253,6 @@ def test_aggregates_are_tuples(self, mocked_project, aggregator_instance): aggregate_store = self.create_aggregate_store( aggregator_instance, mocked_project ) - if not isinstance(aggregate_store, _DefaultAggregateStore): - for aggregate in aggregate_store._generate_aggregates(): - assert isinstance(aggregate, tuple) - assert all(isinstance(job, signac.contrib.job.Job) for job in aggregate) for aggregate in aggregate_store.values(): assert isinstance(aggregate, tuple) assert all(isinstance(job, signac.contrib.job.Job) for job in aggregate) @@ -316,6 +304,7 @@ def test_default_init(self, agg_value): assert aggregate_instance._sort_by is None assert aggregate_instance._sort_ascending assert aggregate_instance._select is None + # Test if default aggregator aggregates everything in a single group assert list(aggregate_instance._aggregator_function(agg_value)) == [ tuple(agg_value) ] @@ -341,7 +330,7 @@ def test_invalid_call(self, param): with pytest.raises(FlowProjectDefinitionError): aggregator_instance(param) - def test_call_without_decorator(self): + def test_call_without_argument(self): aggregate_instance = aggregator() with pytest.raises(FlowProjectDefinitionError): aggregate_instance() @@ -363,15 +352,15 @@ def test_group_by_invalid_key(self): aggregator.groupby(1) def test_groupby_with_valid_type_default_for_Iterable(self): - aggregator.groupby(["half", "even"], default=[-1, -1]) + aggregator.groupby(["half", "is_even"], default=[-1, -1]) def test_groupby_with_invalid_type_default_key_for_Iterable(self): with pytest.raises(TypeError): - aggregator.groupby(["half", "even"], default=-1) + aggregator.groupby(["half", "is_even"], default=-1) def test_groupby_with_invalid_length_default_key_for_Iterable(self): with pytest.raises(ValueError): - aggregator.groupby(["half", "even"], default=[-1, -1, -1]) + aggregator.groupby(["half", "is_even"], default=[-1, -1, -1]) def test_aggregate_hashing(self): # Since we need to store groups on a per aggregate basis in the project, From 3e224541a27f4068fdcf11a938511c61dad78529 Mon Sep 17 00:00:00 2001 From: Hardik Ojha Date: Sat, 20 Aug 2022 10:10:54 +0530 Subject: [PATCH 7/7] address review comments --- tests/test_aggregates.py | 78 +++++++++++++++++----------------------- 1 file changed, 32 insertions(+), 46 deletions(-) diff --git a/tests/test_aggregates.py b/tests/test_aggregates.py index 5b5b49a42..bda96071f 100644 --- a/tests/test_aggregates.py +++ b/tests/test_aggregates.py @@ -2,7 +2,6 @@ from math import ceil import pytest -import signac from conftest import TestProjectBase from flow.aggregates import aggregator, get_aggregate_id @@ -28,32 +27,29 @@ def mocked_project(self): class AggregateFixtures: - @classmethod - def _get_single_job_aggregate(cls, jobs): - for job in jobs: - yield (job,) - - @classmethod - def _get_all_job_aggregate(cls, jobs): - return (tuple(jobs),) - @pytest.fixture def get_single_job_aggregate(self): - return AggregateFixtures._get_single_job_aggregate + def generator(jobs): + yield from ((job,) for job in jobs) + + return generator @pytest.fixture def get_all_job_aggregate(self): - return AggregateFixtures._get_all_job_aggregate + def generator(jobs): + return (tuple(jobs),) - @classmethod - def list_of_aggregators(cls): + return generator + + @pytest.fixture + def list_of_aggregators(self, get_single_job_aggregate, get_all_job_aggregate): # The below list contains 14 distinct aggregator objects and some duplicates. return [ aggregator(), aggregator(), - aggregator(cls._get_all_job_aggregate), - aggregator(cls._get_single_job_aggregate), - aggregator(cls._get_single_job_aggregate), + aggregator(get_all_job_aggregate), + aggregator(get_single_job_aggregate), + aggregator(get_single_job_aggregate), aggregator.groupsof(1), aggregator.groupsof(1), aggregator.groupsof(2), @@ -226,12 +222,12 @@ def select(job): selected_jobs = tuple((job,) for job in mocked_project if select(job)) assert self.get_aggregates_from_store(aggregate_store) == selected_jobs - def test_store_hashing(self, mocked_project): + def test_store_hashing(self, mocked_project, list_of_aggregators): # Since we need to store groups on a per aggregate basis in the mocked_project, # we need to be sure that the aggregates are hashing and compared correctly. list_of_stores = [ self.create_aggregate_store(aggregator, mocked_project) - for aggregator in AggregateFixtures.list_of_aggregators() + for aggregator in list_of_aggregators ] # The above list contains 14 distinct store objects (because a # store object is differentiated on the basis of the @@ -243,22 +239,17 @@ def test_store_hashing(self, mocked_project): assert len(set(list_of_stores)) == 14 @pytest.mark.parametrize( - "aggregator_instance", AggregateFixtures.list_of_aggregators() - ) - def test_aggregates_are_tuples(self, mocked_project, aggregator_instance): - # This test ensures that all aggregator functions return tuples. All - # aggregate stores are expected to return tuples for their values, but - # this also tests that the aggregator functions (groupsof, groupby) are - # generating tuples internally. - aggregate_store = self.create_aggregate_store( - aggregator_instance, mocked_project - ) - for aggregate in aggregate_store.values(): - assert isinstance(aggregate, tuple) - assert all(isinstance(job, signac.contrib.job.Job) for job in aggregate) - - @pytest.mark.parametrize( - "aggregator_instance", AggregateFixtures.list_of_aggregators() + "aggregator_instance", + [ + aggregator(), + aggregator.groupsof(1), + aggregator.groupsof(2), + aggregator.groupsof(3), + aggregator.groupsof(4), + aggregator.groupby("is_even"), + aggregator.groupby("half", -1), + aggregator.groupby(["half", "is_even"], default=[-1, -1]), + ], ) def test_get_by_id(self, mocked_project, aggregator_instance): # Ensure that all aggregates can be fetched by id. @@ -297,17 +288,13 @@ def test_contains(self, mocked_project): # Test the decorator class aggregator class TestAggregate(AggregateProjectSetup, AggregateFixtures): - @pytest.mark.parametrize("agg_value", [(1, 2, 3, 4, 5), (), [1, 2, 3, 4, 5], []]) - def test_default_init(self, agg_value): + def test_default_init(self): aggregate_instance = aggregator() assert not aggregate_instance._is_default_aggregator assert aggregate_instance._sort_by is None assert aggregate_instance._sort_ascending assert aggregate_instance._select is None - # Test if default aggregator aggregates everything in a single group - assert list(aggregate_instance._aggregator_function(agg_value)) == [ - tuple(agg_value) - ] + assert next(aggregate_instance._aggregator_function((1, 2, 3))) == (1, 2, 3) @pytest.mark.parametrize("aggregator_function", ["str", 1, {}]) def test_invalid_aggregator_function(self, aggregator_function): @@ -362,7 +349,7 @@ def test_groupby_with_invalid_length_default_key_for_Iterable(self): with pytest.raises(ValueError): aggregator.groupby(["half", "is_even"], default=[-1, -1, -1]) - def test_aggregate_hashing(self): + def test_aggregate_hashing(self, list_of_aggregators): # Since we need to store groups on a per aggregate basis in the project, # we need to be sure that the aggregates are hashing and compared correctly. # This test ensures this feature. @@ -371,10 +358,9 @@ def test_aggregate_hashing(self): # When this list is converted to set, then these objects are hashed first # and then compared. Since sets don't carry duplicate values, we test # whether the length of the set obtained from the list is equal to 14 or not. - total_aggregators = AggregateFixtures.list_of_aggregators() - assert len(set(total_aggregators)) == 14 + assert len(set(list_of_aggregators)) == 14 # Ensure that equality implies hash equality. - for agg1 in total_aggregators: - for agg2 in total_aggregators: + for agg1 in list_of_aggregators: + for agg2 in list_of_aggregators: if agg1 == agg2: assert hash(agg1) == hash(agg2)