Skip to content

Commit

Permalink
add --must-include flag (#355)
Browse files Browse the repository at this point in the history
* feat: add --must-include flag

Introduce new command line flag `--must-include`.

When running the QUIC Interop Runner in the CI of a single implementation, one
is only interested in test pairs that include that particular implementation.
As an example, take a world with three QUIC implementations: quic-go, ngtcp2
and neqo. On the neqo CI, one is not interested in the client-server pair
quic-go->ngtcp2.

With this commit one can specify `--must-include neqo` on neqo's CI.

```
$ python3 run.py --test handshake --must-include neqo

Saving logs to logs_2024-02-24T17:39:08.
Server: neqo. Client: quic-go. Running test case: handshake
Server: neqo. Client: ngtcp2. Running test case: handshake
Server: quic-go. Client: neqo. Running test case: handshake
Server: ngtcp2. Client: neqo. Running test case: handshake
Server: neqo. Client: neqo. Running test case: handshake
Run took 0:01:50.070133
+---------+------+---------+--------+
|         | neqo | quic-go | ngtcp2 |
+---------+------+---------+--------+
| quic-go | ✓(H) |         |        |
|         | ?()  |         |        |
|         | ✕()  |         |        |
+---------+------+---------+--------+
|  ngtcp2 | ✓(H) |         |        |
|         | ?()  |         |        |
|         | ✕()  |         |        |
+---------+------+---------+--------+
|   neqo  | ✓(H) |   ✓(H)  |  ✓(H)  |
|         | ?()  |   ?()   |  ?()   |
|         | ✕()  |   ✕()   |  ✕()   |
+---------+------+---------+--------+
```

* yapf fmt
  • Loading branch information
mxinden authored Feb 29, 2024
1 parent 92fae15 commit 7f922fa
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 110 deletions.
223 changes: 115 additions & 108 deletions interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ class InteropRunner:
measurement_results = {}
compliant = {}
_implementations = {}
_servers = []
_clients = []
_client_server_pairs = []
_tests = []
_measurements = []
_output = ""
Expand All @@ -55,8 +54,7 @@ class InteropRunner:
def __init__(
self,
implementations: dict,
servers: List[str],
clients: List[str],
client_server_pairs: List[Tuple[str, str]],
tests: List[testcases.TestCase],
measurements: List[testcases.Measurement],
output: str,
Expand All @@ -75,8 +73,7 @@ def __init__(
self._start_time = datetime.now()
self._tests = tests
self._measurements = measurements
self._servers = servers
self._clients = clients
self._client_server_pairs = client_server_pairs
self._implementations = implementations
self._output = output
self._log_dir = log_dir
Expand All @@ -86,16 +83,15 @@ def __init__(
if os.path.exists(self._log_dir):
sys.exit("Log dir " + self._log_dir + " already exists.")
logging.info("Saving logs to %s.", self._log_dir)
for server in servers:
self.test_results[server] = {}
self.measurement_results[server] = {}
for client in clients:
self.test_results[server][client] = {}
for test in self._tests:
self.test_results[server][client][test] = {}
self.measurement_results[server][client] = {}
for measurement in measurements:
self.measurement_results[server][client][measurement] = {}
for client, server in client_server_pairs:
for test in self._tests:
self.test_results.setdefault(server, {}).setdefault(
client, {}
).setdefault(test, {})
for measurement in measurements:
self.measurement_results.setdefault(server, {}).setdefault(
client, {}
).setdefault(measurement, {})

def _is_unsupported(self, lines: List[str]) -> bool:
return any("exited with code 127" in str(line) for line in lines) or any(
Expand Down Expand Up @@ -193,60 +189,73 @@ def get_letters(result):
t = prettytable.PrettyTable()
t.hrules = prettytable.ALL
t.vrules = prettytable.ALL
t.field_names = [""] + [name for name in self._servers]
for client in self._clients:
rows = {}
columns = {}
for client, server in self._client_server_pairs:
columns[server] = {}
row = rows.setdefault(client, {})
cell = self.test_results[server][client]
res = colored(get_letters(TestResult.SUCCEEDED), "green") + "\n"
res += colored(get_letters(TestResult.UNSUPPORTED), "grey") + "\n"
res += colored(get_letters(TestResult.FAILED), "red")
row[server] = res

t.field_names = [""] + [column for column, _ in columns.items()]
for client, results in rows.items():
row = [client]
for server in self._servers:
cell = self.test_results[server][client]
res = colored(get_letters(TestResult.SUCCEEDED), "green") + "\n"
res += colored(get_letters(TestResult.UNSUPPORTED), "grey") + "\n"
res += colored(get_letters(TestResult.FAILED), "red")
row += [res]
for server, _ in columns.items():
row += [results.setdefault(server, "")]
t.add_row(row)
print(t)

if len(self._measurements) > 0:
t = prettytable.PrettyTable()
t.hrules = prettytable.ALL
t.vrules = prettytable.ALL
t.field_names = [""] + [name for name in self._servers]
for client in self._clients:
row = [client]
for server in self._servers:
cell = self.measurement_results[server][client]
results = []
for measurement in self._measurements:
res = cell[measurement]
if not hasattr(res, "result"):
continue
if res.result == TestResult.SUCCEEDED:
results.append(
colored(
measurement.abbreviation() + ": " + res.details,
"green",
)
t.field_names = [""]
rows = {}
columns = {}
for client, server in self._client_server_pairs:
columns[server] = {}
row = rows.setdefault(client, {})
cell = self.measurement_results[server][client]
results = []
for measurement in self._measurements:
res = cell[measurement]
if not hasattr(res, "result"):
continue
if res.result == TestResult.SUCCEEDED:
results.append(
colored(
measurement.abbreviation() + ": " + res.details,
"green",
)
elif res.result == TestResult.UNSUPPORTED:
results.append(colored(measurement.abbreviation(), "grey"))
elif res.result == TestResult.FAILED:
results.append(colored(measurement.abbreviation(), "red"))
row += ["\n".join(results)]
)
elif res.result == TestResult.UNSUPPORTED:
results.append(colored(measurement.abbreviation(), "grey"))
elif res.result == TestResult.FAILED:
results.append(colored(measurement.abbreviation(), "red"))
# `row` comes from rows.setdefault(client, {}), so the first time a
# (client, server) pair is seen row[server] does not exist yet and "+="
# would raise KeyError. Assign the cell, as the test-results table does.
row[server] = "\n".join(results)
t.field_names = [""] + [column for column, _ in columns.items()]
for client, results in rows.items():
row = [client]
for server, _ in columns.items():
row += [results.setdefault(server, "")]
t.add_row(row)
print(t)

def _export_results(self):
if not self._output:
return
# Collect the distinct client and server names from the scheduled pairs.
# Each pair is (client, server): clients come from the first element and
# servers from the second. dict.fromkeys() de-duplicates while keeping
# first-seen order, so the exported lists are stable across runs (a set
# gives no ordering guarantee).
clients = list(dict.fromkeys(client for client, _ in self._client_server_pairs))
servers = list(dict.fromkeys(server for _, server in self._client_server_pairs))
out = {
"start_time": self._start_time.timestamp(),
"end_time": datetime.now().timestamp(),
"log_dir": self._log_dir,
"servers": [name for name in self._servers],
"clients": [name for name in self._clients],
"urls": {
x: self._implementations[x]["url"]
for x in self._servers + self._clients
},
"servers": servers,
"clients": clients,
"urls": {x: self._implementations[x]["url"] for x in clients + servers},
"tests": {
x.abbreviation(): {
"name": x.name(),
Expand All @@ -260,36 +269,35 @@ def _export_results(self):
"measurements": [],
}

for client in self._clients:
for server in self._servers:
results = []
for test in self._tests:
r = None
if hasattr(self.test_results[server][client][test], "value"):
r = self.test_results[server][client][test].value
results.append(
{
"abbr": test.abbreviation(),
"name": test.name(), # TODO: remove
"result": r,
}
)
out["results"].append(results)

measurements = []
for measurement in self._measurements:
res = self.measurement_results[server][client][measurement]
if not hasattr(res, "result"):
continue
measurements.append(
{
"name": measurement.name(), # TODO: remove
"abbr": measurement.abbreviation(),
"result": res.result.value,
"details": res.details,
}
)
out["measurements"].append(measurements)
for client, server in self._client_server_pairs:
results = []
for test in self._tests:
r = None
if hasattr(self.test_results[server][client][test], "value"):
r = self.test_results[server][client][test].value
results.append(
{
"abbr": test.abbreviation(),
"name": test.name(), # TODO: remove
"result": r,
}
)
out["results"].append(results)

measurements = []
for measurement in self._measurements:
res = self.measurement_results[server][client][measurement]
if not hasattr(res, "result"):
continue
measurements.append(
{
"name": measurement.name(), # TODO: remove
"abbr": measurement.abbreviation(),
"result": res.result.value,
"details": res.details,
}
)
out["measurements"].append(measurements)

f = open(self._output, "w")
json.dump(out, f)
Expand Down Expand Up @@ -486,33 +494,32 @@ def run(self):
"""run the interop test suite and output the table"""

nr_failed = 0
for server in self._servers:
for client in self._clients:
logging.debug(
"Running with server %s (%s) and client %s (%s)",
server,
self._implementations[server]["image"],
client,
self._implementations[client]["image"],
)
if not (
self._check_impl_is_compliant(server)
and self._check_impl_is_compliant(client)
):
logging.info("Not compliant, skipping")
continue

# run the test cases
for testcase in self._tests:
status = self._run_testcase(server, client, testcase)
self.test_results[server][client][testcase] = status
if status == TestResult.FAILED:
nr_failed += 1

# run the measurements
for measurement in self._measurements:
res = self._run_measurement(server, client, measurement)
self.measurement_results[server][client][measurement] = res
for client, server in self._client_server_pairs:
logging.debug(
"Running with server %s (%s) and client %s (%s)",
server,
self._implementations[server]["image"],
client,
self._implementations[client]["image"],
)
if not (
self._check_impl_is_compliant(server)
and self._check_impl_is_compliant(client)
):
logging.info("Not compliant, skipping")
continue

# run the test cases
for testcase in self._tests:
status = self._run_testcase(server, client, testcase)
self.test_results[server][client][testcase] = status
if status == TestResult.FAILED:
nr_failed += 1

# run the measurements
for measurement in self._measurements:
res = self._run_measurement(server, client, measurement)
self.measurement_results[server][client][measurement] = res

self._print_results()
self._export_results()
Expand Down
24 changes: 22 additions & 2 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ def get_args():
parser.add_argument(
"-j", "--json", help="output the matrix to file in json format"
)
parser.add_argument(
"-i",
"--must-include",
help="implementation that must be included",
)
return parser.parse_args()

replace_arg = get_args().replace
Expand All @@ -88,6 +93,18 @@ def get_impls(arg, availableImpls, role) -> List[str]:
impls.append(s)
return impls

def get_impl_pairs(
    clients: List[str], servers: List[str], must_include
) -> List[Tuple[str, str]]:
    """Build the list of (client, server) pairs to run.

    Every combination of a client from ``clients`` and a server from
    ``servers`` is considered in client-major order (all servers for the
    first client, then all servers for the second client, and so on).

    Args:
        clients: names of the client implementations.
        servers: names of the server implementations.
        must_include: when ``None``, every combination is returned;
            otherwise only pairs whose client or server equals this value
            are kept.

    Returns:
        The selected pairs as ``(client, server)`` tuples.
    """
    return [
        (client, server)
        for client in clients
        for server in servers
        if must_include is None or must_include in (client, server)
    ]

def get_tests_and_measurements(
arg,
) -> Tuple[List[testcases.TestCase], List[testcases.TestCase]]:
Expand Down Expand Up @@ -124,8 +141,11 @@ def get_tests_and_measurements(
t = get_tests_and_measurements(get_args().test)
return InteropRunner(
implementations=implementations,
servers=get_impls(get_args().server, server_implementations, "Server"),
clients=get_impls(get_args().client, client_implementations, "Client"),
client_server_pairs=get_impl_pairs(
get_impls(get_args().client, client_implementations, "Client"),
get_impls(get_args().server, server_implementations, "Server"),
get_args().must_include,
),
tests=t[0],
measurements=t[1],
output=get_args().json,
Expand Down

0 comments on commit 7f922fa

Please sign in to comment.