generated from VectorInstitute/aieng-template
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP collecting and checking the output
- Loading branch information
Showing 3 changed files with 99 additions and 23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,55 +1,130 @@ | ||
import asyncio | ||
import logging | ||
import subprocess | ||
import time | ||
from asyncio.subprocess import PIPE, STDOUT | ||
|
||
logging.basicConfig() | ||
logging.basicConfig( | ||
format='%(asctime)s %(levelname)-8s %(message)s', | ||
level=logging.DEBUG, | ||
datefmt='%Y-%m-%d %H:%M:%S' | ||
) | ||
logger = logging.getLogger() | ||
logger.setLevel(logging.INFO) | ||
|
||
|
||
def run_smoke_test( | ||
async def run_smoke_test( | ||
n_clients_to_start: int = 4, | ||
config_path: str = "tests/smoke_tests/config.yaml", | ||
dataset_path: str = "examples/datasets/mnist_data/", | ||
) -> None: | ||
# Start the server, divert the outputs to a server file | ||
logger.info("Starting server") | ||
|
||
ps = subprocess.Popen([ | ||
server_process = await asyncio.create_subprocess_exec( | ||
"nohup", | ||
"python", | ||
"-m", | ||
"examples.fedprox_example.server", | ||
"--config_path", | ||
config_path, | ||
], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) | ||
server_output = ps.stdout | ||
stdout=PIPE, | ||
stderr=STDOUT, | ||
) | ||
|
||
# Sleep for 20 seconds to allow the server to come up. | ||
# TODO fix this by capturing the output and parsing it | ||
time.sleep(20) | ||
# reads lines from the server output in search of the startup log message | ||
# times out after 20s of inactivity if it doesn't find the log message | ||
full_server_output = "" | ||
startup_message = "FL starting" | ||
output_found = False | ||
while True: | ||
try: | ||
server_output_in_bytes = await asyncio.wait_for(server_process.stdout.readline(), 20) | ||
server_output = server_output_in_bytes.decode() | ||
full_server_output += server_output | ||
except asyncio.TimeoutError: | ||
logger.error(f"Timeout waiting for server startup message '{startup_message}'") | ||
break | ||
logger.debug(f"Server output: {server_output}") | ||
if startup_message in server_output: | ||
output_found = True | ||
break | ||
|
||
assert output_found, f"Startup log message '{startup_message}' not found in server output." | ||
|
||
logger.info("Server started") | ||
|
||
# Start n number of clients and divert the outputs to their own files | ||
client_outputs = [] | ||
# Start n number of clients and capture their process objects | ||
client_processes = [] | ||
for i in range(n_clients_to_start): | ||
logger.info(f"Starting client {i}") | ||
ps = subprocess.Popen([ | ||
client_process = await asyncio.create_subprocess_exec( | ||
"nohup", | ||
"python", | ||
"-m", | ||
"examples.fedprox_example.client", | ||
"--dataset_path", | ||
dataset_path, | ||
], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) | ||
client_outputs.append(ps.stdout) | ||
stdout=PIPE, | ||
stderr=STDOUT, | ||
) | ||
client_processes.append(client_process) | ||
|
||
# Collecting the clients output until their processes finish | ||
full_client_outputs = [""] * n_clients_to_start | ||
# Clients that have finished execution are set to None, so the loop finishes when all of them are None | ||
# or in other words, while there are still any valid process objects in the list | ||
while any(client_processes): | ||
for i in range(len(client_processes)): | ||
if client_processes[i] is None: | ||
# Clients that have finished execution are set to None | ||
continue | ||
|
||
client_output_in_bytes = await asyncio.wait_for(client_processes[i].stdout.readline(), 20) | ||
client_output = client_output_in_bytes.decode() | ||
logger.debug(f"Client {i} output: {client_output}") | ||
|
||
full_client_outputs[i] += client_output | ||
|
||
# TODO make asserts | ||
# checking for clients with failure exit codes | ||
client_return_code = client_processes[i].returncode | ||
assert client_return_code is None or (client_return_code is not None and client_return_code == 0), \ | ||
f"Client {i} exited with code {client_return_code}" | ||
|
||
if client_output is None or len(client_output) == 0 or client_return_code == 0: | ||
logger.info(f"Client {i} finished execution") | ||
# Setting the client that has finished to None | ||
client_processes[i] = None | ||
|
||
logger.info("All clients finished execution") | ||
|
||
# now wait for the server to finish | ||
while True: | ||
logger.info(f"Server output: {server_output.readline().decode()}") | ||
for i in range(len(client_outputs)): | ||
logger.info(f"Client {i} output: {client_outputs[i].readline().decode()}") | ||
try: | ||
server_output_in_bytes = await asyncio.wait_for(server_process.stdout.readline(), 20) | ||
server_output = server_output_in_bytes.decode() | ||
full_server_output += server_output | ||
logger.debug(f"Server output: {server_output}") | ||
except asyncio.TimeoutError: | ||
logger.debug(f"Server log message retrieval timed out, it has likely finished execution") | ||
break | ||
|
||
# checking whether the server exited with a failure exit code | ||
server_return_code = server_process.returncode | ||
assert server_return_code is None or (server_return_code is not None and server_return_code == 0), \ | ||
f"Server exited with code {server_return_code}" | ||
|
||
if server_output is None or len(server_output) == 0 or server_return_code == 0: | ||
break | ||
|
||
logger.info("Server has finished execution") | ||
|
||
assert "error" not in full_server_output.lower(), "Error message found for server" | ||
for i in range(full_client_outputs): | ||
assert "error" not in full_client_outputs[i].lower(), f"Error message found for client {i}" | ||
# TODO pull this number from the config | ||
assert "Current FL Round: 15" in full_client_outputs[i], f"Last FL round message not found for client {i}" | ||
assert "Disconnect and shut down" in full_client_outputs[i], f"Shutdown message not found for client {i}" | ||
|
||
|
||
if __name__ == "__main__": | ||
run_smoke_test() | ||
loop = asyncio.get_event_loop() | ||
loop.run_until_complete(run_smoke_test()) | ||
loop.close() |