diff --git a/src/pytest_split/plugin.py b/src/pytest_split/plugin.py index 5db7676..3ac4ff0 100644 --- a/src/pytest_split/plugin.py +++ b/src/pytest_split/plugin.py @@ -1,6 +1,7 @@ import json +import pytest import os -from collections import defaultdict, OrderedDict +from collections import defaultdict, OrderedDict, namedtuple from typing import TYPE_CHECKING from _pytest.config import create_terminal_writer @@ -16,6 +17,9 @@ # Ugly hacks for freezegun compatibility: https://github.com/spulec/freezegun/issues/286 STORE_DURATIONS_SETUP_AND_TEARDOWN_THRESHOLD = 60 * 10 # seconds +TestGroup = namedtuple("TestGroup", "index, num_tests") +TestSuite = namedtuple("TestSuite", "splits, num_tests") + def pytest_addoption(parser: "Parser") -> None: group = parser.getgroup( @@ -52,39 +56,88 @@ def pytest_addoption(parser: "Parser") -> None: ) -def pytest_collection_modifyitems(config: "Config", items: "List[nodes.Item]") -> None: - splits = config.option.splits - group = config.option.group - store_durations = config.option.store_durations - durations_report_path = config.option.durations_path +@pytest.mark.tryfirst +def pytest_cmdline_main(config: "Config") -> None: + group = config.getoption("group") + splits = config.getoption("splits") + + if splits is None and group is None: + return + + if splits and group is None: + raise pytest.UsageError("argument `--group` is required") + + if group and splits is None: + raise pytest.UsageError("argument `--splits` is required") + + if splits < 1: + raise pytest.UsageError("argument `--splits` must be >= 1") + + if group < 1 or group > splits: + raise pytest.UsageError(f"argument `--group` must be >= 1 and <= {splits}") + + +class SplitPlugin: + def __init__(self): + self._suite: TestSuite + self._group: TestGroup + self._messages: "List[str]" = [] + + def pytest_report_collectionfinish(self, config: "Config") -> "List[str]": + lines = [] + if self._messages: + lines += self._messages + + if hasattr(self, "_suite"): + lines.append( + f"Running group {self._group.index}/{self._suite.splits}" + f" ({self._group.num_tests}/{self._suite.num_tests}) tests" + ) + + prefix = "[pytest-split]" + lines = [f"{prefix} {m}" for m in lines] + + return lines + + def pytest_collection_modifyitems( + self, config: "Config", items: "List[nodes.Item]" + ) -> None: + splits = config.option.splits + group = config.option.group + store_durations = config.option.store_durations + durations_report_path = config.option.durations_path - if any((splits, group)): - if not all((splits, group)): + if store_durations: + if any((group, splits)): + self._messages.append( + "Not splitting tests because we are storing durations" + ) return None - if not os.path.isfile(durations_report_path): + + if not group and not splits: + # don't split unless explicitly requested return None - if store_durations: - # Don't split if we are storing durations + + if not os.path.isfile(durations_report_path): + self._messages.append( + "Not splitting tests because the durations_report is missing" + ) return None - total_tests_count = len(items) - if splits and group: + with open(durations_report_path) as f: stored_durations = OrderedDict(json.load(f)) start_idx, end_idx = _calculate_suite_start_and_end_idx( splits, group, items, stored_durations ) + + self._suite = TestSuite(splits, len(items)) + self._group = TestGroup(group, end_idx - start_idx) items[:] = items[start_idx:end_idx] - terminal_reporter = config.pluginmanager.get_plugin("terminalreporter") - terminal_writer = create_terminal_writer(config) - message = terminal_writer.markup( - " Running group {}/{} ({}/{} tests)\n".format( - group, splits, len(items), total_tests_count - ) - ) - terminal_reporter.write(message) - return None + +def pytest_configure(config: "Config") -> None: + config.pluginmanager.register(SplitPlugin()) def pytest_sessionfinish(session: "Session") -> None: diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 9f3eee6..4a769ec 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -155,16 +155,6 @@ def test_it_does_not_split_with_invalid_args(self, example_suite, durations_path with open(durations_path, "w") as f: json.dump(durations, f) - result = example_suite.inline_run( - "--splits", "2", "--durations-path", durations_path - ) # no --group - result.assertoutcome(passed=10) - - result = example_suite.inline_run( - "--group", "2", "--durations-path", durations_path - ) # no --splits - result.assertoutcome(passed=10) - result = example_suite.inline_run( "--splits", "2", "--group", "1" ) # no durations report in default location @@ -214,5 +204,104 @@ def test_it_adapts_splits_based_on_new_and_deleted_tests( ] +class TestRaisesUsageErrors: + def test_returns_nonzero_when_group_but_not_splits(self, example_suite, capsys): + result = example_suite.inline_run("--group", "1") + assert result.ret == 4 + + outerr = capsys.readouterr() + assert "argument `--splits` is required" in outerr.err + + def test_returns_nonzero_when_splits_but_not_group(self, example_suite, capsys): + result = example_suite.inline_run("--splits", "1") + assert result.ret == 4 + + outerr = capsys.readouterr() + assert "argument `--group` is required" in outerr.err + + def test_returns_nonzero_when_group_below_one(self, example_suite, capsys): + result = example_suite.inline_run("--splits", "3", "--group", "0") + assert result.ret == 4 + + outerr = capsys.readouterr() + assert "argument `--group` must be >= 1 and <= 3" in outerr.err + + def test_returns_nonzero_when_group_larger_than_splits(self, example_suite, capsys): + result = example_suite.inline_run("--splits", "3", "--group", "4") + assert result.ret == 4 + + outerr = capsys.readouterr() + assert "argument `--group` must be >= 1 and <= 3" in outerr.err + + def test_returns_nonzero_when_splits_below_one(self, example_suite, capsys): + result = example_suite.inline_run("--splits", "0", "--group", "1") + assert result.ret == 4 + + outerr = capsys.readouterr() + assert "argument `--splits` must be >= 1" in outerr.err + + +class TestHasExpectedOutput: + def test_does_not_print_splitting_summary_when_durations_missing( + self, example_suite, capsys + ): + result = example_suite.inline_run("--splits", "1", "--group", "1") + assert result.ret == 0 + + outerr = capsys.readouterr() + assert ( + "[pytest-split] Not splitting tests because the durations_report is missing" + in outerr.out + ) + assert "[pytest-split] Running group" not in outerr.out + + def test_prints_splitting_summary_when_durations_present( + self, example_suite, capsys, durations_path + ): + test_name = "test_prints_splitting_summary_when_durations_present" + with open(durations_path, "w") as f: + json.dump([[f"{test_name}0/{test_name}.py::test_1", 0.5]], f) + result = example_suite.inline_run( + "--splits", "1", "--group", "1", "--durations-path", durations_path + ) + assert result.ret == 0 + + outerr = capsys.readouterr() + assert "[pytest-split] Running group 1/1 (10/10) tests" in outerr.out + + def test_prints_splitting_summary_when_storing_durations( + self, example_suite, capsys, durations_path + ): + test_name = "test_prints_splitting_summary_when_storing_durations" + with open(durations_path, "w") as f: + json.dump([[f"{test_name}0/{test_name}.py::test_1", 0.5]], f) + + result = example_suite.inline_run( + "--splits", + "1", + "--group", + "1", + "--durations-path", + durations_path, + "--store-durations", + ) + assert result.ret == 0 + + outerr = capsys.readouterr() + assert ( + "[pytest-split] Not splitting tests because we are storing durations" + in outerr.out + ) + + def test_does_not_print_splitting_summary_when_no_pytest_split_arguments( + self, example_suite, capsys + ): + result = example_suite.inline_run() + assert result.ret == 0 + + outerr = capsys.readouterr() + assert "[pytest-split]" not in outerr.out + + def _passed_test_names(result): return [passed.nodeid.split("::")[-1] for passed in result.listoutcomes()[0]]