Skip to content

Commit

Permalink
Using python script instead of sh
Browse files Browse the repository at this point in the history
  • Loading branch information
lotif committed Nov 24, 2023
1 parent f17fcbd commit 800ba73
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/smoke_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ jobs:
pip install --upgrade pip
pip install $(grep -v '^torchdata\|^torchtext\|^torcheval' requirements.txt)
- name: Run Script
run: sh examples/utils/run_fl_local.sh # TODO create new script to check for logs, move it to the tests folder
run: python -m tests.smoke_tests.run_smoke_test
- name: Clearing pip for space
run: pip uninstall -y -r requirements.txt
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ We use the standard git development flow of branch and merge to main with PRs on

The library dependencies and those for development are listed in the `pyproject.toml` and `requirements.txt` files. You may use whatever virtual environment management tool that you would like. These include conda, poetry, and virtualenv. Poetry is used to produce our releases, which are managed and automated by GitHub.

The easiest way to create and activate a virtual environment is
The easiest way to create and activate a virtual environment is by using the [virtualenv](https://pypi.org/project/virtualenv/) package:
```bash
virtualenv "ENV_PATH"
source "ENV_PATH/bin/activate"
Expand Down
Empty file added tests/smoke_tests/__init__.py
Empty file.
25 changes: 25 additions & 0 deletions tests/smoke_tests/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Parameters that describe server
n_server_rounds: 15 # The number of rounds to run FL

# Following the adaptive proximal weight setup in Appendix C3.3 of the FedProx paper: https://arxiv.org/pdf/1812.06127.pdf
# Server decides to increase or decrease the proximal weight based on the average training losses of clients that were sent during training.
# The updated proximal weight is then sent to the clients for the next round of training.
# Set the initial proximal weight to 0.0 for adaptive proximal weight setup.
adaptive_proximal_weight: True # Whether to use adaptive proximal weight or not
proximal_weight : 0.0 # Initial proximal weight
proximal_weight_delta : 0.1 # The amount by which to increase or decrease the proximal weight, if adaptive_proximal_weight is True
proximal_weight_patience : 5 # The number of rounds to wait before increasing or decreasing the proximal weight, if adaptive_proximal_weight is True

# Parameters that describe clients
n_clients: 3 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client
batch_size: 128 # The batch size for client training

reporting_config:
enabled: False
project_name: FL4Health # Name of the project under which everything should be logged
run_name: "FedProx Server" # Name of the run on the server-side, each client will also have it's own run name
group_name: "FedProx Experiment" # Group under which each of the FL run logging will be stored
entity: "your_entity_here" # WandB user name
notes: "Testing WB reporting"
tags: ["Test", "FedProx"]
55 changes: 55 additions & 0 deletions tests/smoke_tests/run_smoke_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import logging
import subprocess
import time

logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def run_smoke_test(
n_clients_to_start: int = 4,
config_path: str = "tests/smoke_tests/config.yaml",
dataset_path: str = "examples/datasets/mnist_data/",
) -> None:
# Start the server, divert the outputs to a server file
logger.info("Starting server")

ps = subprocess.Popen([
"nohup",
"python",
"-m",
"examples.fedprox_example.server",
"--config_path",
config_path,
], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
server_output = ps.stdout

# Sleep for 20 seconds to allow the server to come up.
# TODO fix this by capturing the output and parsing it
time.sleep(20)

# Start n number of clients and divert the outputs to their own files
client_outputs = []
for i in range(n_clients_to_start):
logger.info(f"Starting client {i}")
ps = subprocess.Popen([
"nohup",
"python",
"-m",
"examples.fedprox_example.client",
"--dataset_path",
dataset_path,
], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
client_outputs.append(ps.stdout)

# TODO make asserts

while True:
logger.info(f"Server output: {server_output.readline()}")
for i in range(len(client_outputs)):
logger.info(f"Client {i} output: {client_outputs[i].readline()}")


if __name__ == "__main__":
run_smoke_test()

0 comments on commit 800ba73

Please sign in to comment.