forked from google/grr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
conftest.py
199 lines (152 loc) · 6 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
#!/usr/bin/env python
"""A module that configures the behaviour of pytest runner."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import sys
import threading
import traceback
from absl import flags
import pytest
from grr_response_core.lib.util import compatibility
from grr.test_lib import testing_startup
# Global absl flag registry; test modules register their flags on it.
FLAGS = flags.FLAGS

# Marker attached to `benchmark`-marked tests when --benchmark is not passed.
SKIP_BENCHMARK = pytest.mark.skip(
    reason="benchmark tests are executed only with --benchmark flag")

# Command-line arguments that followed a "--" separator. Set by
# pytest_cmdline_preparse; remains None until that hook runs.
test_args = None
def pytest_cmdline_preparse(config, args):
  """Splits test-program arguments off the pytest command line.

  Everything after the first "--" in `args` is stored in the module-level
  `test_args`. The separator and everything after it are then removed from
  `args` in place, so pytest never tries to interpret them. When there is no
  separator, `test_args` becomes an empty list and `args` is left unchanged.

  Args:
    config: Unused pytest config object.
    args: Mutable list of command-line arguments handed over by pytest.
  """
  del config  # Unused.
  global test_args

  cut = args.index("--") if "--" in args else len(args)
  test_args = args[cut + 1:]
  # Truncate the caller's list object itself; rebinding would have no effect.
  args[cut:] = []
def pytest_cmdline_main(config):
  """Rewrites `sys.argv` so that later absl flag parsing sees test arguments.

  In a pytest-xdist worker process, detected via the PYTEST_XDIST_WORKER
  environment variable, `sys.argv` is taken from the controller's
  `config.workerinput["mainargv"]`. Otherwise `sys.argv` becomes "pytest"
  followed by the arguments captured after "--" by pytest_cmdline_preparse.

  Args:
    config: The pytest config object.
  """
  if "PYTEST_XDIST_WORKER" in os.environ:
    # If run concurrently using pytest-xdist (`-n` cli flag), mainargv is the
    # result of the execution of pytest_cmdline_main in the main process.
    sys.argv = config.workerinput["mainargv"]
  else:
    # `test_args` is still None if pytest_cmdline_preparse was never called
    # (e.g. a pytest version that no longer dispatches that deprecated hook).
    # Fall back to "no extra arguments" rather than crashing with TypeError.
    # TODO: `sys.argv` on Python 2 uses `bytes` to represent passed
    # arguments.
    sys.argv = [compatibility.NativeStr("pytest")] + list(test_args or [])
# Module of the most recently set-up test item (None before the first test).
last_module = None


def pytest_runtest_setup(item):
  """Re-parses flags and re-runs testing init whenever the test module changes.

  Different test modules may define different absl flags, so `FLAGS` has to
  be re-parsed from `sys.argv` (and testing startup re-run) once per module.
  Consecutive items from the same module skip this work.

  Args:
    item: The pytest test item about to be executed.
  """
  global last_module

  module = item.module
  if module == last_module:
    return

  FLAGS(sys.argv)
  testing_startup.TestInit()
  last_module = module
def pytest_addoption(parser):
  """Registers this conftest's custom command-line options with pytest.

  Options added:
    -B / --benchmark: run tests carrying the `benchmark` marker.
    --full_thread_trace: append every thread's stack to thread-leak errors.

  Args:
    parser: The pytest option parser.
  """
  option_specs = (
      (("-B", "--benchmark"), {
          "dest": "benchmark",
          "default": False,
          "action": "store_true",
          "help": "run tests marked as benchmarks",
      }),
      (("--full_thread_trace",), {
          "action": "store_true",
          "default": False,
          "help": ("Include a full stacktrace for all thread in case of a "
                   "thread leak."),
      }),
  )
  for names, settings in option_specs:
    parser.addoption(*names, **settings)
def pytest_collection_modifyitems(session, config, items):
  """Skips `benchmark`-marked tests unless --benchmark was passed.

  Args:
    session: Unused pytest session.
    config: The pytest config; consulted for the `benchmark` option.
    items: Collected test items; benchmark items get SKIP_BENCHMARK added.
  """
  del session  # Unused.
  if config.getoption("benchmark"):
    return

  for item in items:
    # get_closest_marker() searches the same markers iter_markers() yields,
    # but returns at most one match. This adds the skip marker only once,
    # even when `benchmark` is applied at several levels (module/class/func).
    if item.get_closest_marker("benchmark") is not None:
      item.add_marker(SKIP_BENCHMARK)
def _generate_full_thread_trace():
"""Generates a full stack trace for all currently running threads."""
threads = threading.enumerate()
res = "Stacktrace for:\n"
for thread in threads:
res += "%s (id %d)\n" % (thread.name, thread.ident)
res += "\n"
frames = sys._current_frames() # pylint: disable=protected-access
for thread_id, stack in frames.items():
res += "Thread ID: %s\n" % thread_id
for filename, lineno, name, line in traceback.extract_stack(stack):
res += "File: '%s', line %d, in %s\n" % (filename, lineno, name)
if line:
res += " %s\n" % (line.strip())
return res
# Name of the most recent test whose setup passed the leak check.
last_test_name = None
# Thread names already reported as leaked. Each entry excuses one thread of
# that name in later checks, so a given leak is reported only once.
known_leaks = []

# Thread names that may legitimately outlive an individual test. Each entry
# excuses at most one live thread with that name.
_ALLOWED_THREAD_NAMES = (
    "MainThread",
    # One thread per connector is started and left running, because starting
    # them is expensive.
    "ApiRegressionHttpConnectorV1",
    "ApiRegressionHttpConnectorV2",
    # Selenium is slow to set up; it is cleaned up by an atexit handler.
    "SeleniumServerThread",
    # Created in setUpClass and torn down in tearDownClass, so these are not
    # real leaks.
    "api_integration_server",
    "ApiSslServerTest",
    "GRRHTTPServerTestThread",
    "SharedMemDBTestThread",
)


@pytest.fixture(scope="function", autouse=True)
def thread_leak_check(request):
  """Fails the upcoming test if unexpected threads are still alive.

  This runs during fixture setup, so a thread leaked by one test is reported
  when the next test starts. Any newly reported names are added to
  `known_leaks`.

  Args:
    request: The pytest fixture request for the upcoming test.

  Raises:
    RuntimeError: if live threads remain after excusing allowed and
      already-reported names. The message names the previous and the
      upcoming test. With --full_thread_trace it also includes all stacks.
  """
  global last_test_name

  # threading.enumerate() also returns "dummy" Thread objects. These are
  # created by threading.current_thread() for threads that were not started
  # through the threading module. CPython never removes them ("they aren't
  # garbage collected when they die, nor can they be waited for"), and
  # their names start with "Dummy-". Termination of such alien threads
  # cannot be detected, so they are ignored here. See
  # https://docs.python.org/3/library/threading.html,
  # https://github.com/python/cpython/blob/2a16eea71f56c2d8f38c295c8ce71a9a9a140aff/Lib/threading.py#L1269
  # and
  # https://stackoverflow.com/questions/55778365/what-is-dummy-in-threading-current-thread
  live_names = [
      t.name for t in threading.enumerate() if not t.name.startswith("Dummy-")
  ]

  # Each excused name removes only the first matching live thread.
  for excused in _ALLOWED_THREAD_NAMES + tuple(known_leaks):
    if excused in live_names:
      live_names.remove(excused)

  upcoming_test = request.node.name
  if not live_names:
    last_test_name = upcoming_test
    return

  # Remember these names so the same leak does not fail every later test.
  known_leaks.extend(live_names)
  message = ("Detected unexpected thread(s): %s. "
             "Last test was %s, next test is %s." %
             (live_names, last_test_name, upcoming_test))
  if request.config.getoption("full_thread_trace"):
    message += "\n\n" + _generate_full_thread_trace()
  raise RuntimeError(message)