Commit

wip
sondrelg committed May 30, 2021
2 parents 869700c + fc2badf commit 7161e05
Showing 5 changed files with 229 additions and 219 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,3 +1,3 @@
[tool.black]
line-length = 88
line-length = 120
include = '\.pyi?$'
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,5 +1,5 @@
[flake8]
max-line-length = 88
max-line-length = 120
exclude =
tests/*
ignore = ANN101, W503
2 changes: 1 addition & 1 deletion src/pytest_split/__init__.py
@@ -1,3 +1,3 @@
__version__ = '0.1.5'
from ._version import version as __version__

__all__ = ("__version__",)
269 changes: 168 additions & 101 deletions src/pytest_split/plugin.py
@@ -1,28 +1,24 @@
import json
import os
from collections import defaultdict, OrderedDict
from collections import OrderedDict
from typing import TYPE_CHECKING, Tuple, Generator
from warnings import warn

import _pytest
import pytest
from _pytest.config import create_terminal_writer
from _pytest.config.argparsing import Parser
from _pytest.main import Session
from _pytest.config import create_terminal_writer, hookimpl

if TYPE_CHECKING:
from typing import List

from _pytest import nodes
from _pytest.config import Config
from _pytest.config.argparsing import Parser

# Ugly hacks for freezegun compatibility:
# https://github.com/spulec/freezegun/issues/286
STORE_DURATIONS_SETUP_AND_TEARDOWN_THRESHOLD = 60 * 10 # seconds
CACHE_PATH = ".pytest_cache/v/cache/pytest_split"

@pytest.hookimpl()
def pytest_addoption(parser: Parser) -> None:

def pytest_addoption(parser: "Parser") -> None:
"""
Declare plugin options.
"""
@@ -31,15 +27,6 @@ def pytest_addoption(parser: Parser) -> None:
"Run first the whole suite with --store-durations to save information "
"about test execution times"
)
group.addoption(
"--durations-path",
dest="durations_path",
help=(
"Path to the file in which durations are (to be) stored. "
f"By default, durations will be written to {CACHE_PATH}"
),
default=os.path.join(os.getcwd(), CACHE_PATH),
)
group.addoption(
"--splits",
dest="splits",
@@ -52,21 +39,33 @@ def pytest_addoption(parser: Parser) -> None:
type=int,
help="The group of tests that should be executed (first one is 1)",
)
group.addoption(
"--durations-path",
dest="durations_path",
help=(
"Path to the file in which durations are (to be) stored. "
f"By default, durations will be written to {CACHE_PATH}"
),
default="",
)
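
For context, the three options above combine like this (a usage sketch based on the flags declared in this diff; the file path is hypothetical):

    # First run: record test durations
    pytest --store-durations --durations-path .test_durations

    # Later runs: split the suite into 3 groups and run the first one
    pytest --splits 3 --group 1 --durations-path .test_durations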


@pytest.hookimpl(trylast=True)
def pytest_configure(config: "Config") -> None:
"""
Configure plugin.
Enable the plugin if appropriate arguments are passed.
"""
if (config.option.splits and not config.option.group) or (
config.option.group and not config.option.split
):
if config.option.splits and not config.option.group:
warn(
"It looks like you passed an argument to run pytest with pytest-split, "
"but both the `splits` and `group` arguments are required for pytest-split to run"
"Both the `splits` and `group` arguments are required for pytest-split "
"to run. Remove the `splits` argument or add a `groups` argument."
)
if config.option.splits and config.option.group:
elif config.option.group and not config.option.splits:
warn(
"Both the `splits` and `group` arguments are required for pytest-split "
"to run. Remove the `groups` argument or add a `splits` argument."
)
elif config.option.splits and config.option.group:
# Register plugin to run only if we received a splits and group arg
config.pluginmanager.register(PytestSplitPlugin(config), "pytestsplitplugin")


@@ -77,15 +76,24 @@ def __init__(self, config: "Config") -> None:
"""
Load cache and configure plugin.
"""
self.cached_durations = dict(config.cache.get(self.cache_file, {}))
self.config = config
if config.option.durations_path:
with open(config.option.durations_path, "r") as f:
self.cached_durations = json.loads(f.read())
else:
self.cached_durations = dict(config.cache.get(self.cache_file, {}))

self.writer = create_terminal_writer(self.config)
if not self.cached_durations:
warn(
self.writer.line()
self.writer.line(
"No test durations found. Pytest-split will "
"split tests evenly when no durations are found, "
"so you should expect better results in following "
"runs when test timings have been documented."
"split tests evenly when no durations are found. "
"\nYou can expect better results in consequent runs, "
"when test timings have been documented."
)

@hookimpl(hookwrapper=True, tryfirst=True)
def pytest_collection_modifyitems(self, config: "Config", items: "List[nodes.Item]") -> Generator[None, None, None]:
"""
Instruct Pytest to run the tests we've selected.
@@ -98,82 +106,141 @@ def pytest_collection_modifyitems(self, config: "Config", items: "List[nodes.Ite
# Load plugin arguments
splits: int = config.option.splits
group: int = config.option.group
durations_report_path: str = config.option.durations_path

total_tests_count = len(items)
stored_durations = OrderedDict(config.cache.get(self.cache_file, {}))

start_idx, end_idx = self._calculate_suite_start_and_end_idx(splits, group, items, stored_durations)
items[:] = items[start_idx:end_idx]
selected_tests, deselected_tests = self._split_tests(splits, group, items, self.cached_durations)

writer = create_terminal_writer(config)
message = writer.markup(
" Running group {}/{} ({}/{} tests)\n".format(
group, splits, len(items), total_tests_count
)
)
writer.line(message)
items[:] = selected_tests
config.hook.pytest_deselected(items=deselected_tests)

def pytest_sessionfinish(self, session: "Session") -> None:
if session.config.option.store_durations:
report_path = session.config.option.durations_path
terminal_reporter = session.config.pluginmanager.get_plugin(
"terminalreporter"
)
durations = defaultdict(float)
for test_reports in terminal_reporter.stats.values():
for test_report in test_reports:
if hasattr(test_report, "duration"):
stage = getattr(test_report, "when", "")
duration = test_report.duration
# These ifs can be removed after this is solved:
# https://github.com/spulec/freezegun/issues/286
if duration < 0:
continue
if (
stage in ("teardown", "setup")
and duration > STORE_DURATIONS_SETUP_AND_TEARDOWN_THRESHOLD
):
# Ignore implausible teardown durations
continue
durations[test_report.nodeid] += test_report.duration

with open(report_path, "w") as f:
f.write(json.dumps(list(durations.items()), indent=2))

terminal_writer = create_terminal_writer(session.config)
message = terminal_writer.markup(
" Stored test durations in {}\n".format(report_path)
)
terminal_reporter.write(message)
message = self.writer.markup(
" Running group {}/{} ({}/{} tests)\n".format(group, splits, len(items), total_tests_count)
)
self.writer.line(message)
yield
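
A hookwrapper like the one above runs around all regular implementations of the same hook: code before the bare `yield` executes first, and code after it executes once the other implementations have finished. A generic sketch of the pattern (not specific to this plugin):

    @hookimpl(hookwrapper=True)
    def pytest_collection_modifyitems(config, items):
        # runs before non-wrapper implementations of this hook
        outcome = yield
        # runs after them; outcome.get_result() exposes the hook's return value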

@staticmethod
def _calculate_suite_start_and_end_idx(splits: int, group: int, items: "List[nodes.Item]", stored_durations: OrderedDict) -> Tuple[int, int]:
item_node_ids = [item.nodeid for item in items]
stored_durations = {
k: v for k, v in stored_durations.items() if k in item_node_ids
}
avg_duration_per_test = sum(stored_durations.values()) / len(stored_durations)

durations = OrderedDict()
for node_id in item_node_ids:
durations[node_id] = stored_durations.get(node_id, avg_duration_per_test)

time_per_group = sum(durations.values()) / splits
start_time = time_per_group * (group - 1)
end_time = time_per_group + start_time
start_idx = end_idx = duration_rolling_sum = 0

for idx, duration in enumerate(durations.values()):
duration_rolling_sum += duration
if group != 1 and not start_idx and duration_rolling_sum > start_time:
start_idx = idx
if group == splits:
def _split_tests(
splits: int,
group: int,
items: "List[nodes.Item]",
stored_durations: OrderedDict,
) -> "Tuple[List[nodes.Item], List[nodes.Item]]":
"""
Split tests by runtime.
The splitting logic is very simple. We find out what our threshold runtime
is per group, then start adding tests (slowest tests ordered first) until we
get close to the threshold runtime per group. We then reverse the ordering and
add the fastest tests available until we go just *beyond* the threshold.
The choice we're making is to overload the first groups a little bit. The reason
this is reasonable is that CI providers like GHA will usually spin up the first
groups first, meaning that if you had a perfect test split, the first groups
would still finish first. The *overloading* is also minimal, so it shouldn't
matter in most cases.
After assigning tests to each group we select the group we're in
and deselect all remaining tests.
:param splits: How many groups we're splitting in.
:param group: Which group this run represents.
:param items: Test items passed down by Pytest.
:param stored_durations: Our cached test runtimes.
:return:
Tuple of two lists.
The first list represents the tests we want to run,
while the other represents the tests we want to deselect.
"""
# Filter stored durations down to only the relevant tests' durations -
# this way the average duration per test is calculated on relevant tests only
test_names = [item.nodeid for item in items]
durations = {k: v for k, v in stored_durations.items() if k in test_names}

# Compute the average duration, used as a fallback for tests not in the cache
if durations:
avg_duration_per_test = sum(durations.values()) / len(durations)
else:
# If there are no durations, we give every test the same assumed arbitrary value
avg_duration_per_test = 1

# Create a dict of test item: runtime
tests_and_durations = {item: durations.get(item.nodeid, avg_duration_per_test) for item in items}

# Set the threshold runtime value per group
time_per_group = sum(tests_and_durations.values()) / splits

# Order the dict so the slowest tests appear first
sorted_tests_and_durations = OrderedDict(sorted(tests_and_durations.items(), key=lambda x: x[1], reverse=True))

selected, deselected = [], []

# Finally, we split tests equally between groups
for _group in range(1, splits + 1):
group_tests, group_runtime = [], 0

# Add slow tests up until *one more test would cross the threshold*
for item in OrderedDict(sorted_tests_and_durations):
if group_runtime + sorted_tests_and_durations[item] > time_per_group:
break
group_tests.append(item)
group_runtime += sorted_tests_and_durations.pop(item)

# Add fast tests until *we do cross the threshold*
for item in OrderedDict(sorted(sorted_tests_and_durations.items(), key=lambda x: x[1], reverse=False)):
if group_runtime > time_per_group:
break
elif group != splits and not end_idx and duration_rolling_sum > end_time:
end_idx = idx
break
if not end_idx:
end_idx = len(items)
group_tests.append(item)
group_runtime += sorted_tests_and_durations.pop(item)

return start_idx, end_idx
if _group == group:
selected = group_tests
else:
deselected.extend(group_tests)

return selected, deselected
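
To make the two-phase fill concrete, here is a standalone sketch of the same idea over plain name-to-duration dicts (test names and timings are invented for illustration):

    def split_once(durations, threshold):
        remaining = dict(sorted(durations.items(), key=lambda x: x[1], reverse=True))
        group, runtime = [], 0.0
        # Phase 1: slowest first, stop *before* a test would cross the threshold
        for name in list(remaining):
            if runtime + remaining[name] > threshold:
                break
            group.append(name)
            runtime += remaining.pop(name)
        # Phase 2: fastest first, stop once we *have* crossed the threshold
        for name in sorted(remaining, key=remaining.get):
            if runtime > threshold:
                break
            group.append(name)
            runtime += remaining.pop(name)
        return group, remaining

    durations = {"test_a": 10.0, "test_b": 6.0, "test_c": 4.0, "test_d": 2.0, "test_e": 2.0}
    first, rest = split_once(durations, sum(durations.values()) / 2)  # threshold = 12.0
    # first == ["test_a", "test_d", "test_e"] with runtime 14.0: slightly
    # overloaded, exactly as the docstring describes; rest holds test_b and test_c.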

def pytest_sessionfinish(self) -> None:
"""
Write test runtimes to cache.
Method is called by Pytest after the test-suite has run.
https://github.com/pytest-dev/pytest/blob/main/src/_pytest/main.py#L308
"""
terminal_reporter = self.config.pluginmanager.get_plugin("terminalreporter")
test_durations = {}

for test_reports in terminal_reporter.stats.values():
for test_report in test_reports:
if hasattr(test_report, "duration"):
# These ifs can be removed after this is solved:
# https://github.com/spulec/freezegun/issues/286
if test_report.duration < 0:
continue
if (
getattr(test_report, "when", "") in ("teardown", "setup")
and test_report.duration > STORE_DURATIONS_SETUP_AND_TEARDOWN_THRESHOLD
):
# Ignore implausible setup/teardown durations
continue

# Add test durations to map
if test_report.nodeid not in test_durations:
test_durations[test_report.nodeid] = 0
test_durations[test_report.nodeid] += test_report.duration

# Update the full cached-durations object
for k, v in test_durations.items():
self.cached_durations[k] = v

# Save to cache
self.config.cache.set(self.cache_file, self.cached_durations)

# Save to custom file if needed
if self.config.option.durations_path:
with open(self.config.option.durations_path, "w") as f:
f.write(json.dumps(self.cached_durations))

message = self.writer.markup(" Stored test durations in {}\n".format(self.config.option.durations_path))
self.writer.line(message)
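
The durations file written above is a flat JSON object mapping test node IDs to accumulated runtimes in seconds, roughly like this (node IDs and values invented for illustration):

    {
      "tests/test_plugin.py::test_split_three_groups": 0.42,
      "tests/test_plugin.py::test_cache_roundtrip": 1.37
    }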
