generated from VectorInstitute/aieng-template
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
82 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Parameters that describe server | ||
n_server_rounds: 15 # The number of rounds to run FL | ||
|
||
# Following the adaptive proximal weight setup in Appendix C3.3 of the FedProx paper: https://arxiv.org/pdf/1812.06127.pdf | ||
# Server decides to increase or decrease the proximal weight based on the average training losses of clients that were sent during training. | ||
# The updated proximal weight is then sent to the clients for the next round of training. | ||
# Set the initial proximal weight to 0.0 for adaptive proximal weight setup. | ||
adaptive_proximal_weight: True # Whether to use adaptive proximal weight or not | ||
proximal_weight : 0.0 # Initial proximal weight | ||
proximal_weight_delta : 0.1 # The amount by which to increase or decrease the proximal weight, if adaptive_proximal_weight is True | ||
proximal_weight_patience : 5 # The number of rounds to wait before increasing or decreasing the proximal weight, if adaptive_proximal_weight is True | ||
|
||
# Parameters that describe clients | ||
n_clients: 3 # The number of clients in the FL experiment | ||
local_epochs: 1 # The number of epochs to complete for client | ||
batch_size: 128 # The batch size for client training | ||
|
||
reporting_config: | ||
enabled: False | ||
project_name: FL4Health # Name of the project under which everything should be logged | ||
run_name: "FedProx Server" # Name of the run on the server-side, each client will also have it's own run name | ||
group_name: "FedProx Experiment" # Group under which each of the FL run logging will be stored | ||
entity: "your_entity_here" # WandB user name | ||
notes: "Testing WB reporting" | ||
tags: ["Test", "FedProx"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import logging | ||
import subprocess | ||
import time | ||
|
||
logging.basicConfig() | ||
logger = logging.getLogger() | ||
logger.setLevel(logging.INFO) | ||
|
||
|
||
def run_smoke_test( | ||
n_clients_to_start: int = 4, | ||
config_path: str = "tests/smoke_tests/config.yaml", | ||
dataset_path: str = "examples/datasets/mnist_data/", | ||
) -> None: | ||
# Start the server, divert the outputs to a server file | ||
logger.info("Starting server") | ||
|
||
ps = subprocess.Popen([ | ||
"nohup", | ||
"python", | ||
"-m", | ||
"examples.fedprox_example.server", | ||
"--config_path", | ||
config_path, | ||
], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) | ||
server_output = ps.stdout | ||
|
||
# Sleep for 20 seconds to allow the server to come up. | ||
# TODO fix this by capturing the output and parsing it | ||
time.sleep(20) | ||
|
||
# Start n number of clients and divert the outputs to their own files | ||
client_outputs = [] | ||
for i in range(n_clients_to_start): | ||
logger.info(f"Starting client {i}") | ||
ps = subprocess.Popen([ | ||
"nohup", | ||
"python", | ||
"-m", | ||
"examples.fedprox_example.client", | ||
"--dataset_path", | ||
dataset_path, | ||
], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) | ||
client_outputs.append(ps.stdout) | ||
|
||
# TODO make asserts | ||
|
||
while True: | ||
logger.info(f"Server output: {server_output.readline()}") | ||
for i in range(len(client_outputs)): | ||
logger.info(f"Client {i} output: {client_outputs[i].readline()}") | ||
|
||
|
||
if __name__ == "__main__": | ||
run_smoke_test() |