Skip to content

Commit

Permalink
WIP collecting and checking the output
Browse files Browse the repository at this point in the history
  • Loading branch information
lotif committed Nov 27, 2023
1 parent 933882c commit 46361b1
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 23 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/smoke_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ jobs:
uses: actions/setup-python@v3
with:
python-version: 3.9
cache: 'pip' # Display the Python version being used
cache: 'pip'
# Display the Python version being used
- name: Display Python version
run: python -c "import sys; print(sys.version)"
- name: Install Requirements
Expand Down
2 changes: 1 addition & 1 deletion tests/smoke_tests/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ proximal_weight_delta : 0.1 # The amount by which to increase or decrease the pr
proximal_weight_patience : 5 # The number of rounds to wait before increasing or decreasing the proximal weight, if adaptive_proximal_weight is True

# Parameters that describe clients
n_clients: 3 # The number of clients in the FL experiment
n_clients: 4 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client
batch_size: 128 # The batch size for client training

Expand Down
117 changes: 96 additions & 21 deletions tests/smoke_tests/run_smoke_test.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,130 @@
import asyncio
import logging
import subprocess
import time
from asyncio.subprocess import PIPE, STDOUT

logging.basicConfig()
logging.basicConfig(
format='%(asctime)s %(levelname)-8s %(message)s',
level=logging.DEBUG,
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def run_smoke_test(
async def run_smoke_test(
n_clients_to_start: int = 4,
config_path: str = "tests/smoke_tests/config.yaml",
dataset_path: str = "examples/datasets/mnist_data/",
) -> None:
# Start the server, divert the outputs to a server file
logger.info("Starting server")

ps = subprocess.Popen([
server_process = await asyncio.create_subprocess_exec(
"nohup",
"python",
"-m",
"examples.fedprox_example.server",
"--config_path",
config_path,
], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
server_output = ps.stdout
stdout=PIPE,
stderr=STDOUT,
)

# Sleep for 20 seconds to allow the server to come up.
# TODO fix this by capturing the output and parsing it
time.sleep(20)
# reads lines from the server output in search of the startup log message
# times out after 20s of inactivity if it doesn't find the log message
full_server_output = ""
startup_message = "FL starting"
output_found = False
while True:
try:
server_output_in_bytes = await asyncio.wait_for(server_process.stdout.readline(), 20)
server_output = server_output_in_bytes.decode()
full_server_output += server_output
except asyncio.TimeoutError:
logger.error(f"Timeout waiting for server startup message '{startup_message}'")
break
logger.debug(f"Server output: {server_output}")
if startup_message in server_output:
output_found = True
break

assert output_found, f"Startup log message '{startup_message}' not found in server output."

logger.info("Server started")

# Start n number of clients and divert the outputs to their own files
client_outputs = []
# Start n number of clients and capture their process objects
client_processes = []
for i in range(n_clients_to_start):
logger.info(f"Starting client {i}")
ps = subprocess.Popen([
client_process = await asyncio.create_subprocess_exec(
"nohup",
"python",
"-m",
"examples.fedprox_example.client",
"--dataset_path",
dataset_path,
], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
client_outputs.append(ps.stdout)
stdout=PIPE,
stderr=STDOUT,
)
client_processes.append(client_process)

# Collecting the clients output until their processes finish
full_client_outputs = [""] * n_clients_to_start
# Clients that have finished execution are set to None, so the loop finishes when all of them are None
# or in other words, while there are still any valid process objects in the list
while any(client_processes):
for i in range(len(client_processes)):
if client_processes[i] is None:
# Clients that have finished execution are set to None
continue

client_output_in_bytes = await asyncio.wait_for(client_processes[i].stdout.readline(), 20)
client_output = client_output_in_bytes.decode()
logger.debug(f"Client {i} output: {client_output}")

full_client_outputs[i] += client_output

# TODO make asserts
# checking for clients with failure exit codes
client_return_code = client_processes[i].returncode
assert client_return_code is None or (client_return_code is not None and client_return_code == 0), \
f"Client {i} exited with code {client_return_code}"

if client_output is None or len(client_output) == 0 or client_return_code == 0:
logger.info(f"Client {i} finished execution")
# Setting the client that has finished to None
client_processes[i] = None

logger.info("All clients finished execution")

# now wait for the server to finish
while True:
logger.info(f"Server output: {server_output.readline().decode()}")
for i in range(len(client_outputs)):
logger.info(f"Client {i} output: {client_outputs[i].readline().decode()}")
try:
server_output_in_bytes = await asyncio.wait_for(server_process.stdout.readline(), 20)
server_output = server_output_in_bytes.decode()
full_server_output += server_output
logger.debug(f"Server output: {server_output}")
except asyncio.TimeoutError:
logger.debug(f"Server log message retrieval timed out, it has likely finished execution")
break

# checking for clients with failure exit codes
server_return_code = server_process.returncode
assert server_return_code is None or (server_return_code is not None and server_return_code == 0), \
f"Server exited with code {server_return_code}"

if server_output is None or len(server_output) == 0 or server_return_code == 0:
break

logger.info("Server has finished execution")

assert "error" not in full_server_output.lower(), "Error message found for server"
for i in range(full_client_outputs):
assert "error" not in full_client_outputs[i].lower(), f"Error message found for client {i}"
# TODO pull this number from the config
assert "Current FL Round: 15" in full_client_outputs[i], f"Last FL round message not found for client {i}"
assert "Disconnect and shut down" in full_client_outputs[i], f"Shutdown message not found for client {i}"


if __name__ == "__main__":
run_smoke_test()
loop = asyncio.get_event_loop()
loop.run_until_complete(run_smoke_test())
loop.close()

0 comments on commit 46361b1

Please sign in to comment.