Merge pull request #14 from sondrelg/write-to-cache
Enable test duration writes on all runs
jerry-git authored Jun 8, 2021
2 parents 8728d94 + 604948d commit beed24d
Showing 5 changed files with 234 additions and 257 deletions.
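For context, a minimal usage sketch of the options this change touches (the durations file name and group numbers are illustrative, not taken from this diff): record timings with --store-durations, then split later runs with --splits and --group.

    pytest --store-durations --durations-path durations.json     # write test durations for this run
    pytest --splits 3 --group 1 --durations-path durations.json  # run the first of three groups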
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
@@ -32,8 +32,8 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [ "3.6", "3.7", "3.8", "3.9", ] # "3.10.0-beta.1"
pytest-version: [ "4", "5", "6" ]
python-version: [ "3.6", "3.7", "3.8", "3.9" ]
pytest-version: [ "5", "6" ]
steps:
- name: Check out repository
uses: actions/checkout@v2
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,3 +1,3 @@
[tool.black]
line-length = 88
line-length = 120
include = '\.pyi?$'
3 changes: 2 additions & 1 deletion setup.cfg
@@ -1,5 +1,6 @@
[flake8]
max-line-length = 88
max-line-length = 120
ignore = ANN101, W503
select =
# B: flake8-bugbear
B,
279 changes: 159 additions & 120 deletions src/pytest_split/plugin.py
@@ -1,37 +1,37 @@
import json
import pytest
import os
from collections import defaultdict, OrderedDict, namedtuple
from typing import TYPE_CHECKING

from _pytest.config import create_terminal_writer
import pytest
from _pytest.config import create_terminal_writer, hookimpl
from _pytest.reports import TestReport

if TYPE_CHECKING:
from typing import List, Tuple
from _pytest.config.argparsing import Parser
from _pytest.main import Session
from typing import List, Tuple, Optional, Union

from _pytest import nodes

from _pytest.main import ExitCode
from _pytest.config import Config
from _pytest.config.argparsing import Parser

# Ugly hacks for freezegun compatibility: https://github.com/spulec/freezegun/issues/286
# Ugly hack for freezegun compatibility: https://github.com/spulec/freezegun/issues/286
STORE_DURATIONS_SETUP_AND_TEARDOWN_THRESHOLD = 60 * 10 # seconds

TestGroup = namedtuple("TestGroup", "index, num_tests")
TestSuite = namedtuple("TestSuite", "splits, num_tests")


def pytest_addoption(parser: "Parser") -> None:
"""
Declare pytest-split's options.
"""
group = parser.getgroup(
"Split tests into groups which execution time is about the same. "
"Run first the whole suite with --store-durations to save information "
"about test execution times"
"Run with --store-durations to store information about test execution times."
)
group.addoption(
"--store-durations",
dest="store_durations",
action="store_true",
help="Store durations into '--durations-path'",
help="Store durations into '--durations-path'.",
)
group.addoption(
"--durations-path",
@@ -57,12 +57,15 @@ def pytest_addoption(parser: "Parser") -> None:


@pytest.mark.tryfirst
def pytest_cmdline_main(config: "Config") -> None:
def pytest_cmdline_main(config: "Config") -> "Optional[Union[int, ExitCode]]":
"""
Validate options.
"""
group = config.getoption("group")
splits = config.getoption("splits")

if splits is None and group is None:
return
return None

if splits and group is None:
raise pytest.UsageError("argument `--group` is required")
@@ -76,130 +79,166 @@ def pytest_cmdline_main(config: "Config") -> None:
if group < 1 or group > splits:
raise pytest.UsageError(f"argument `--group` must be >= 1 and <= {splits}")

return None

class SplitPlugin:
def __init__(self):
self._suite: TestSuite
self._group: TestGroup
self._messages: "List[str]" = []

def pytest_report_collectionfinish(self, config: "Config") -> "List[str]":
lines = []
if self._messages:
lines += self._messages
def pytest_configure(config: "Config") -> None:
"""
Enable the plugins we need.
"""
if config.option.splits and config.option.group:
config.pluginmanager.register(PytestSplitPlugin(config), "pytestsplitplugin")

if hasattr(self, "_suite"):
lines.append(
f"Running group {self._group.index}/{self._suite.splits}"
f" ({self._group.num_tests}/{self._suite.num_tests}) tests"
)
if config.option.store_durations:
config.pluginmanager.register(PytestSplitCachePlugin(config), "pytestsplitcacheplugin")

prefix = "[pytest-split]"
lines = [f"{prefix} {m}" for m in lines]

return lines

def pytest_collection_modifyitems(
self, config: "Config", items: "List[nodes.Item]"
) -> None:
splits = config.option.splits
group = config.option.group
store_durations = config.option.store_durations
durations_report_path = config.option.durations_path

if store_durations:
if any((group, splits)):
self._messages.append(
"Not splitting tests because we are storing durations"
)
return None

if not group and not splits:
# don't split unless explicitly requested
return None

if not os.path.isfile(durations_report_path):
self._messages.append(
"Not splitting tests because the durations_report is missing"
)
return None

with open(durations_report_path) as f:
stored_durations = OrderedDict(json.load(f))
class Base:
def __init__(self, config: "Config") -> None:
"""
Load durations and set up a terminal writer.
start_idx, end_idx = _calculate_suite_start_and_end_idx(
splits, group, items, stored_durations
)
This logic is shared by both the split and cache plugins.
"""
self.config = config
self.writer = create_terminal_writer(self.config)

self._suite = TestSuite(splits, len(items))
self._group = TestGroup(group, end_idx - start_idx)
items[:] = items[start_idx:end_idx]
try:
with open(config.option.durations_path, "r") as f:
self.cached_durations = json.loads(f.read())
except FileNotFoundError:
self.cached_durations = {}

# This code provides backwards compatibility after we switched
# from saving durations in a list-of-lists to a dict format
# Remove this when bumping to v1
if isinstance(self.cached_durations, list):
self.cached_durations = {test_name: duration for test_name, duration in self.cached_durations}

def pytest_configure(config: "Config") -> None:
config.pluginmanager.register(SplitPlugin())

class PytestSplitPlugin(Base):
def __init__(self, config: "Config"):
super().__init__(config)

self._messages: "List[str]" = []

if not self.cached_durations:
message = self.writer.markup(
"\n[pytest-split] No test durations found. Pytest-split will "
"split tests evenly when no durations are found. "
"\n[pytest-split] You can expect better results in consequent runs, "
"when test timings have been documented.\n"
)
self.writer.line(message)

@hookimpl(tryfirst=True)
def pytest_collection_modifyitems(self, config: "Config", items: "List[nodes.Item]") -> None:
"""
Collect and select the tests we want to run, and deselect the rest.
"""
splits: int = config.option.splits
group: int = config.option.group

selected_tests, deselected_tests = self._split_tests(splits, group, items, self.cached_durations)

items[:] = selected_tests
config.hook.pytest_deselected(items=deselected_tests)

self.writer.line(self.writer.markup(f"\n\n[pytest-split] Running group {group}/{splits}\n"))
return None

@staticmethod
def _split_tests(
splits: int,
group: int,
items: "List[nodes.Item]",
stored_durations: dict,
) -> "Tuple[list, list]":
"""
Split tests into groups by runtime.
:param splits: How many groups we're splitting in.
:param group: Which group this run represents.
:param items: Test items passed down by Pytest.
:param stored_durations: Our cached test runtimes.
:return:
Tuple of two lists.
The first list represents the tests we want to run,
while the other represents the tests we want to deselect.
"""
# Filtering down durations to relevant ones ensures the avg isn't skewed by irrelevant data
test_ids = [item.nodeid for item in items]
durations = {k: v for k, v in stored_durations.items() if k in test_ids}

if durations:
avg_duration_per_test = sum(durations.values()) / len(durations)
else:
# If there are no durations, give every test the same arbitrary value
avg_duration_per_test = 1

tests_and_durations = {item: durations.get(item.nodeid, avg_duration_per_test) for item in items}
time_per_group = sum(tests_and_durations.values()) / splits
selected, deselected = [], []

for _group in range(1, splits + 1):
group_tests, group_runtime = [], 0

for item in dict(tests_and_durations):
if group_runtime > time_per_group:
break

group_tests.append(item)
group_runtime += tests_and_durations.pop(item)

if _group == group:
selected = group_tests
else:
deselected.extend(group_tests)

return selected, deselected


class PytestSplitCachePlugin(Base):
"""
The cache plugin writes durations to our durations file.
"""

def pytest_sessionfinish(self) -> None:
"""
This method is called by pytest after the test suite has run.
https://github.com/pytest-dev/pytest/blob/main/src/_pytest/main.py#L308
"""
terminal_reporter = self.config.pluginmanager.get_plugin("terminalreporter")
test_durations = {}

def pytest_sessionfinish(session: "Session") -> None:
if session.config.option.store_durations:
report_path = session.config.option.durations_path
terminal_reporter = session.config.pluginmanager.get_plugin("terminalreporter")
durations: dict = defaultdict(float)
for test_reports in terminal_reporter.stats.values():
for test_report in test_reports:
if hasattr(test_report, "duration"):
stage = getattr(test_report, "when", "")
duration = test_report.duration
# These ifs can be removed after this is solved:
# https://github.com/spulec/freezegun/issues/286
if duration < 0:
if isinstance(test_report, TestReport):

# These ifs can be removed after this is solved: https://github.com/spulec/freezegun/issues/286
if test_report.duration < 0:
continue
if (
stage in ("teardown", "setup")
and duration > STORE_DURATIONS_SETUP_AND_TEARDOWN_THRESHOLD
test_report.when in ("teardown", "setup")
and test_report.duration > STORE_DURATIONS_SETUP_AND_TEARDOWN_THRESHOLD
):
# Ignore setup/teardown durations that aren't legitimate
continue
durations[test_report.nodeid] += test_report.duration

with open(report_path, "w") as f:
f.write(json.dumps(list(durations.items()), indent=2))
# Add test durations to map
if test_report.nodeid not in test_durations:
test_durations[test_report.nodeid] = 0
test_durations[test_report.nodeid] += test_report.duration

terminal_writer = create_terminal_writer(session.config)
message = terminal_writer.markup(
" Stored test durations in {}\n".format(report_path)
)
terminal_reporter.write(message)
# Update the full cached-durations object
for k, v in test_durations.items():
self.cached_durations[k] = v

# Save durations
with open(self.config.option.durations_path, "w") as f:
json.dump(self.cached_durations, f)

def _calculate_suite_start_and_end_idx(
splits: int, group: int, items: "List[nodes.Item]", stored_durations: OrderedDict
) -> "Tuple[int, int]":
item_node_ids = [item.nodeid for item in items]
stored_durations = OrderedDict(
{k: v for k, v in stored_durations.items() if k in item_node_ids}
)
avg_duration_per_test = sum(stored_durations.values()) / len(stored_durations)

durations = OrderedDict()
for node_id in item_node_ids:
durations[node_id] = stored_durations.get(node_id, avg_duration_per_test)

time_per_group = sum(durations.values()) / splits
start_time = time_per_group * (group - 1)
end_time = time_per_group + start_time
start_idx = end_idx = duration_rolling_sum = 0

for idx, duration in enumerate(durations.values()):
duration_rolling_sum += duration
if group != 1 and not start_idx and duration_rolling_sum > start_time:
start_idx = idx
if group == splits:
break
elif group != splits and not end_idx and duration_rolling_sum > end_time:
end_idx = idx
break
if not end_idx:
end_idx = len(items)

return start_idx, end_idx
message = self.writer.markup(
"\n\n[pytest-split] Stored test durations in {}".format(self.config.option.durations_path)
)
self.writer.line(message)
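For reference, the durations file written above is a flat JSON object mapping test node IDs to accumulated durations in seconds; a small illustrative example (the test paths and numbers below are made up):

    {
      "tests/test_plugin.py::test_split": 0.31,
      "tests/test_plugin.py::test_cache": 1.24
    }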