From 7f922fa66461bb7dcac80abd68820a64846ada74 Mon Sep 17 00:00:00 2001 From: Max Inden Date: Thu, 29 Feb 2024 13:47:57 +0100 Subject: [PATCH] add --must-include flag (#355) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add --must-include flag Introduce new command line flag `--must-include`. When running the QUIC Interop Runner in CI of a single implementation, one is only interested in test pairs including that particular implementation. As an example, take a world with 3 QUIC implementations, quic-go, ngtcp2 and neqo. On the neqo CI one is not interested in the client-server pair quic-go->ngtcp2. With this commit one can specify `--must-include neqo` on neqo's CI. ``` $ python3 run.py --test handshake --must-include neqo Saving logs to logs_2024-02-24T17:39:08. Server: neqo. Client: quic-go. Running test case: handshake Server: neqo. Client: ngtcp2. Running test case: handshake Server: quic-go. Client: neqo. Running test case: handshake Server: ngtcp2. Client: neqo. Running test case: handshake Server: neqo. Client: neqo. Running test case: handshake Run took 0:01:50.070133 +---------+------+---------+--------+ | | neqo | quic-go | ngtcp2 | +---------+------+---------+--------+ | quic-go | ✓(H) | | | | | ?() | | | | | ✕() | | | +---------+------+---------+--------+ | ngtcp2 | ✓(H) | | | | | ?() | | | | | ✕() | | | +---------+------+---------+--------+ | neqo | ✓(H) | ✓(H) | ✓(H) | | | ?() | ?() | ?() | | | ✕() | ✕() | ✕() | +---------+------+---------+--------+ ``` * yapf fmt --- interop.py | 223 +++++++++++++++++++++++++++-------------------------- run.py | 24 +++++- 2 files changed, 137 insertions(+), 110 deletions(-) diff --git a/interop.py b/interop.py index 024e64a1..b610b00c 100644 --- a/interop.py +++ b/interop.py @@ -44,8 +44,7 @@ class InteropRunner: measurement_results = {} compliant = {} _implementations = {} - _servers = [] - _clients = [] + _client_server_pairs = [] _tests = [] _measurements = [] _output = "" @@ -55,8 +54,7 @@ class InteropRunner: def __init__( self, implementations: dict, - servers: List[str], - clients: List[str], + client_server_pairs: List[Tuple[str, str]], tests: List[testcases.TestCase], measurements: List[testcases.Measurement], output: str, @@ -75,8 +73,7 @@ def __init__( self._start_time = datetime.now() self._tests = tests self._measurements = measurements - self._servers = servers - self._clients = clients + self._client_server_pairs = client_server_pairs self._implementations = implementations self._output = output self._log_dir = log_dir @@ -86,16 +83,15 @@ def __init__( if os.path.exists(self._log_dir): sys.exit("Log dir " + self._log_dir + " already exists.") logging.info("Saving logs to %s.", self._log_dir) - for server in servers: - self.test_results[server] = {} - self.measurement_results[server] = {} - for client in clients: - self.test_results[server][client] = {} - for test in self._tests: - self.test_results[server][client][test] = {} - self.measurement_results[server][client] = {} - for measurement in measurements: - self.measurement_results[server][client][measurement] = {} + for client, server in client_server_pairs: + for test in self._tests: + self.test_results.setdefault(server, {}).setdefault( + client, {} + ).setdefault(test, {}) + for measurement in measurements: + self.measurement_results.setdefault(server, {}).setdefault( + client, {} + ).setdefault(measurement, {}) def _is_unsupported(self, lines: List[str]) -> bool: return any("exited with code 127" in str(line) for line in lines) or any( @@ -193,15 +189,22 @@ def get_letters(result): t = prettytable.PrettyTable() t.hrules = prettytable.ALL t.vrules = prettytable.ALL - t.field_names = [""] + [name for name in self._servers] - for client in self._clients: + rows = {} + columns = {} + for client, server in self._client_server_pairs: + columns[server] = {} + row = rows.setdefault(client, {}) + cell = self.test_results[server][client] + res = colored(get_letters(TestResult.SUCCEEDED), "green") + "\n" + res += colored(get_letters(TestResult.UNSUPPORTED), "grey") + "\n" + res += colored(get_letters(TestResult.FAILED), "red") + row[server] = res + + t.field_names = [""] + [column for column, _ in columns.items()] + for client, results in rows.items(): row = [client] - for server in self._servers: - cell = self.test_results[server][client] - res = colored(get_letters(TestResult.SUCCEEDED), "green") + "\n" - res += colored(get_letters(TestResult.UNSUPPORTED), "grey") + "\n" - res += colored(get_letters(TestResult.FAILED), "red") - row += [res] + for server, _ in columns.items(): + row += [results.setdefault(server, "")] t.add_row(row) print(t) @@ -209,44 +212,50 @@ def get_letters(result): t = prettytable.PrettyTable() t.hrules = prettytable.ALL t.vrules = prettytable.ALL - t.field_names = [""] + [name for name in self._servers] - for client in self._clients: - row = [client] - for server in self._servers: - cell = self.measurement_results[server][client] - results = [] - for measurement in self._measurements: - res = cell[measurement] - if not hasattr(res, "result"): - continue - if res.result == TestResult.SUCCEEDED: - results.append( - colored( - measurement.abbreviation() + ": " + res.details, - "green", - ) + t.field_names = [""] + rows = {} + columns = {} + for client, server in self._client_server_pairs: + columns[server] = {} + row = rows.setdefault(client, {}) + cell = self.measurement_results[server][client] + results = [] + for measurement in self._measurements: + res = cell[measurement] + if not hasattr(res, "result"): + continue + if res.result == TestResult.SUCCEEDED: + results.append( + colored( + measurement.abbreviation() + ": " + res.details, + "green", ) - elif res.result == TestResult.UNSUPPORTED: - results.append(colored(measurement.abbreviation(), "grey")) - elif res.result == TestResult.FAILED: - results.append(colored(measurement.abbreviation(), "red")) - row += ["\n".join(results)] + ) + elif res.result == TestResult.UNSUPPORTED: + results.append(colored(measurement.abbreviation(), "grey")) + elif res.result == TestResult.FAILED: + results.append(colored(measurement.abbreviation(), "red")) + row[server] += "\n".join(results) + t.field_names = [""] + [column for column, _ in columns.items()] + for client, results in rows.items(): + row = [client] + for server, _ in columns.items(): + row += [results.setdefault(server, "")] t.add_row(row) print(t) def _export_results(self): if not self._output: return + clients = list(set(server for client, server in self._client_server_pairs)) + servers = list(set(server for client, server in self._client_server_pairs)) out = { "start_time": self._start_time.timestamp(), "end_time": datetime.now().timestamp(), "log_dir": self._log_dir, - "servers": [name for name in self._servers], - "clients": [name for name in self._clients], - "urls": { - x: self._implementations[x]["url"] - for x in self._servers + self._clients - }, + "servers": servers, + "clients": clients, + "urls": {x: self._implementations[x]["url"] for x in clients + servers}, "tests": { x.abbreviation(): { "name": x.name(), @@ -260,36 +269,35 @@ def _export_results(self): "measurements": [], } - for client in self._clients: - for server in self._servers: - results = [] - for test in self._tests: - r = None - if hasattr(self.test_results[server][client][test], "value"): - r = self.test_results[server][client][test].value - results.append( - { - "abbr": test.abbreviation(), - "name": test.name(), # TODO: remove - "result": r, - } - ) - out["results"].append(results) - - measurements = [] - for measurement in self._measurements: - res = self.measurement_results[server][client][measurement] - if not hasattr(res, "result"): - continue - measurements.append( - { - "name": measurement.name(), # TODO: remove - "abbr": measurement.abbreviation(), - "result": res.result.value, - "details": res.details, - } - ) - out["measurements"].append(measurements) + for client, server in self._client_server_pairs: + results = [] + for test in self._tests: + r = None + if hasattr(self.test_results[server][client][test], "value"): + r = self.test_results[server][client][test].value + results.append( + { + "abbr": test.abbreviation(), + "name": test.name(), # TODO: remove + "result": r, + } + ) + out["results"].append(results) + + measurements = [] + for measurement in self._measurements: + res = self.measurement_results[server][client][measurement] + if not hasattr(res, "result"): + continue + measurements.append( + { + "name": measurement.name(), # TODO: remove + "abbr": measurement.abbreviation(), + "result": res.result.value, + "details": res.details, + } + ) + out["measurements"].append(measurements) f = open(self._output, "w") json.dump(out, f) @@ -486,33 +494,32 @@ def run(self): """run the interop test suite and output the table""" nr_failed = 0 - for server in self._servers: - for client in self._clients: - logging.debug( - "Running with server %s (%s) and client %s (%s)", - server, - self._implementations[server]["image"], - client, - self._implementations[client]["image"], - ) - if not ( - self._check_impl_is_compliant(server) - and self._check_impl_is_compliant(client) - ): - logging.info("Not compliant, skipping") - continue - - # run the test cases - for testcase in self._tests: - status = self._run_testcase(server, client, testcase) - self.test_results[server][client][testcase] = status - if status == TestResult.FAILED: - nr_failed += 1 - - # run the measurements - for measurement in self._measurements: - res = self._run_measurement(server, client, measurement) - self.measurement_results[server][client][measurement] = res + for client, server in self._client_server_pairs: + logging.debug( + "Running with server %s (%s) and client %s (%s)", + server, + self._implementations[server]["image"], + client, + self._implementations[client]["image"], + ) + if not ( + self._check_impl_is_compliant(server) + and self._check_impl_is_compliant(client) + ): + logging.info("Not compliant, skipping") + continue + + # run the test cases + for testcase in self._tests: + status = self._run_testcase(server, client, testcase) + self.test_results[server][client][testcase] = status + if status == TestResult.FAILED: + nr_failed += 1 + + # run the measurements + for measurement in self._measurements: + res = self._run_measurement(server, client, measurement) + self.measurement_results[server][client][measurement] = res self._print_results() self._export_results() diff --git a/run.py b/run.py index fbd9515b..0cf3b91e 100755 --- a/run.py +++ b/run.py @@ -65,6 +65,11 @@ def get_args(): parser.add_argument( "-j", "--json", help="output the matrix to file in json format" ) + parser.add_argument( + "-i", + "--must-include", + help="implementation that must be included", + ) return parser.parse_args() replace_arg = get_args().replace @@ -88,6 +93,18 @@ def get_impls(arg, availableImpls, role) -> List[str]: impls.append(s) return impls + def get_impl_pairs(clients, servers, must_include) -> List[Tuple[str, str]]: + impls = [] + for client in clients: + for server in servers: + if ( + must_include is None + or client == must_include + or server == must_include + ): + impls.append((client, server)) + return impls + def get_tests_and_measurements( arg, ) -> Tuple[List[testcases.TestCase], List[testcases.TestCase]]: @@ -124,8 +141,11 @@ def get_tests_and_measurements( t = get_tests_and_measurements(get_args().test) return InteropRunner( implementations=implementations, - servers=get_impls(get_args().server, server_implementations, "Server"), - clients=get_impls(get_args().client, client_implementations, "Client"), + client_server_pairs=get_impl_pairs( + get_impls(get_args().client, client_implementations, "Client"), + get_impls(get_args().server, server_implementations, "Server"), + get_args().must_include, + ), tests=t[0], measurements=t[1], output=get_args().json,