-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Adding Material Dashboard bootstrap template * Adding venv to gitignore * Bump * Fixing static code errors * Making it into a next js app * Deleting old code * Removing about link * cleanup * Adding information about building and running the project into readmen and contributing files. * Moving things around to comply with what poetry expects * done endpoint * added tests, updated readme * Adding return type for client connect * Fixing tests workflow * Testing updated precommit checks * Adding inits * Adding init, fixing mypy error * Adding docstrings * Fixing tests path * Add fiest pass of functional launcher * Updated code that redirects output from server and client processes to files. Slightly modify isort config to make it compatible with black. Add documentation. * Add a bit more documentation, fix typing and pre-commit * Address Marcelos PR suggestions including test restructuing, asserting based on presence of string in log file, ability to specify name of server and client log files and a few other small changes * Fix formatting * Add missing arguments from docstrings, move launch file to not be in the root of api, ensure dump happens after shutdown for client and server, consistent format of log file related input in launch func between client and server and other small changes * Fix pre-commit formatting issue * Fix grammar error in docstring * Remove unwanted metrics files * update path of test_launch --------- Co-authored-by: Marcelo Lotif <[email protected]>
- Loading branch information
1 parent
125d70e
commit 927405b
Showing
10 changed files
with
3,907 additions
and
283 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
import logging | ||
import sys | ||
import time | ||
from multiprocessing import Process | ||
from typing import Callable, Sequence | ||
|
||
import flwr as fl | ||
from fl4health.clients.basic_client import BasicClient | ||
from fl4health.server.base_server import FlServer | ||
from flwr.common.logger import DEFAULT_FORMATTER | ||
from flwr.server import ServerConfig | ||
|
||
|
||
def redirect_logging_from_console_to_file(log_file_path: str) -> None: | ||
""" | ||
Function that redirects loggers outputing to console to specified file. | ||
Args: | ||
log_file_path (str): The path to the file to log to. | ||
""" | ||
|
||
# Define file handler to log to and set format | ||
fh = logging.FileHandler(log_file_path) | ||
fh.setFormatter(DEFAULT_FORMATTER) | ||
|
||
# Loop through existing loggers to check if they have one or more streamhandlers | ||
# If they do, remove them (to prevent logging to the console) and add filehandler | ||
for name in logging.root.manager.loggerDict: | ||
logger = logging.getLogger(name) | ||
if not all([isinstance(h, logging.StreamHandler) is False for h in logger.handlers]): | ||
logger.handlers = [h for h in logger.handlers if not isinstance(h, logging.StreamHandler)] | ||
logger.addHandler(fh) | ||
|
||
|
||
def start_server( | ||
server_constructor: Callable[..., FlServer], | ||
server_address: str, | ||
n_server_rounds: int, | ||
server_log_file_name: str, | ||
) -> None: | ||
""" | ||
Function to start server. Redirects logging to console, stdout and stderr to file. | ||
Args: | ||
server_constructor (Callable[FlServer]): Callable that constructs FL server. | ||
server_address (str): String of <IP>:<PORT> to make server available. | ||
n_server_rounds (str): The number of rounds to perform FL | ||
server_log_file_name (str): The name of the server log file. | ||
""" | ||
redirect_logging_from_console_to_file(server_log_file_name) | ||
log_file = open(server_log_file_name, "a") | ||
# Send remaining ouput (ie print) from stdout and stderr to file | ||
sys.stdout = log_file | ||
sys.stderr = log_file | ||
server = server_constructor() | ||
fl.server.start_server( | ||
server=server, | ||
server_address=server_address, | ||
config=ServerConfig(num_rounds=n_server_rounds), | ||
) | ||
server.shutdown() | ||
server.metrics_reporter.dump() | ||
log_file.close() | ||
|
||
|
||
def start_client(client: BasicClient, server_address: str, client_log_file_name: str) -> None: | ||
""" | ||
Function to start client. Redirects logging to console, stdout and stderr to file. | ||
Args: | ||
client (BasicClient): BasicClient instance to launch. | ||
server_address (str): String of <IP>:<PORT> where the server is available. | ||
client_log_file_name (str): The name of the client log file. | ||
""" | ||
redirect_logging_from_console_to_file(client_log_file_name) | ||
log_file = open(client_log_file_name, "a") | ||
# Send remaining ouput (ie print) from stdout and stderr to file | ||
sys.stdout = log_file | ||
sys.stderr = log_file | ||
fl.client.start_numpy_client(server_address=server_address, client=client) | ||
client.shutdown() | ||
client.metrics_reporter.dump() | ||
log_file.close() | ||
|
||
|
||
def launch_server( | ||
server_constructor: Callable[..., FlServer], | ||
server_address: str, | ||
n_server_rounds: int, | ||
server_log_file_name: str, | ||
seconds_to_sleep: int = 10, | ||
) -> Process: | ||
""" | ||
Function that spawns a process that starts FL server. | ||
Args: | ||
server_constructor (Callable[FlServer]): Callable that constructs FL server. | ||
server_address (str): String of <IP>:<PORT> to make server available. | ||
n_server_rounds (str): The number of rounds to perform FL. | ||
server_log_file_name (str): The name of the log file for the server. | ||
seconds_to_sleep (int): The number of seconds to sleep before launching server. | ||
Returns: | ||
Process: The process running the FL server. | ||
""" | ||
server_process = Process( | ||
target=start_server, | ||
args=( | ||
server_constructor, | ||
server_address, | ||
n_server_rounds, | ||
server_log_file_name, | ||
), | ||
) | ||
server_process.start() | ||
time.sleep(seconds_to_sleep) | ||
return server_process | ||
|
||
|
||
def launch_client(client: BasicClient, server_address: str, client_log_file_name: str) -> None: | ||
""" | ||
Function that spawns a process that starts FL client. | ||
Args: | ||
client (BasicClient): BasicClient instance to launch. | ||
server_address (str): String of <IP>:<PORT> to make server available. | ||
client_log_file_name: (Optional[str]): The name used for the client log file. | ||
""" | ||
client_process = Process(target=start_client, args=(client, server_address, client_log_file_name)) | ||
client_process.start() | ||
|
||
|
||
def launch( | ||
server_constructor: Callable[..., FlServer], | ||
server_address: str, | ||
n_server_rounds: int, | ||
clients: Sequence[BasicClient], | ||
server_base_log_file_name: str = "server", | ||
client_base_log_file_name: str = "client", | ||
) -> None: | ||
""" | ||
Function to launch FL experiment. First launches server than subsequently clients. | ||
Joins server process after clients are launched to block until FL is complete. | ||
(Server is last to finish executing) | ||
Args: | ||
server_constructor (Callable[FlServer]): Callable that constructs FL server. | ||
server_address (str): String of <IP>:<PORT> to make server available. | ||
n_server_rounds (str): The number of rounds to perform FL | ||
clients (Sequence[BasicClient]): List of BasicClient instances to launch. | ||
server_base_log_file_name: (Optional[str]): The name used for the server log file. | ||
client_base_log_file_name: (Optional[str]): The base name used for the client log file. | ||
_{i}.out appended to name to get final client log file name. | ||
""" | ||
server_log_file_name = f"{server_base_log_file_name}.out" | ||
server_process = launch_server(server_constructor, server_address, n_server_rounds, server_log_file_name) | ||
for i, client in enumerate(clients): | ||
client_log_file_name = f"{client_base_log_file_name}_{i}.out" | ||
launch_client(client, server_address, client_log_file_name) | ||
server_process.join() |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import os | ||
import re | ||
import tempfile | ||
from functools import partial | ||
from pathlib import Path | ||
from typing import Callable, Dict | ||
|
||
import torch | ||
from fl4health.server.base_server import FlServer | ||
|
||
from florist.api.launchers.launch import launch | ||
from florist.tests.utils.api.fl4health_utils import MnistClient, get_server_fedavg | ||
from florist.tests.utils.api.models import MnistNet | ||
|
||
|
||
def fit_config(batch_size: int, local_epochs: int, current_server_round: int) -> Dict[str, int]: | ||
return { | ||
"batch_size": batch_size, | ||
"current_server_round": current_server_round, | ||
"local_epochs": local_epochs, | ||
} | ||
|
||
|
||
def get_server( | ||
fit_config: Callable[..., Dict[str, int]] = fit_config, | ||
n_clients: int = 2, | ||
batch_size: int = 8, | ||
local_epochs: int = 1, | ||
) -> FlServer: | ||
fit_config_fn = partial(fit_config, batch_size, local_epochs) | ||
server = get_server_fedavg(model=MnistNet(), n_clients=n_clients, fit_config_fn=fit_config_fn) | ||
return server | ||
|
||
|
||
def assert_string_in_file(file_path: str, search_string: str) -> bool: | ||
with open(file_path, "r") as f: | ||
file_contents = f.read() | ||
match = re.search(search_string, file_contents) | ||
return match is not None | ||
|
||
|
||
def test_launch() -> None: | ||
n_clients = 2 | ||
n_server_rounds = 2 | ||
server_address = "0.0.0.0:8080" | ||
|
||
with tempfile.TemporaryDirectory() as temp_dir: | ||
client_data_paths = [Path(f"{temp_dir}/{i}") for i in range(n_clients)] | ||
for client_data_path in client_data_paths: | ||
os.mkdir(client_data_path) | ||
clients = [MnistClient(client_data_path, [], torch.device("cpu")) for client_data_path in client_data_paths] | ||
|
||
server_path = os.path.join(temp_dir, "server") | ||
client_base_path = f"{temp_dir}/client" | ||
launch( | ||
get_server, | ||
server_address, | ||
n_server_rounds, | ||
clients, | ||
server_path, | ||
client_base_path, | ||
) | ||
|
||
assert_string_in_file(f"{server_path}.out", "FL finished in") |
Empty file.
File renamed without changes.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
from typing import Callable, List, Tuple | ||
|
||
import torch | ||
import torch.nn as nn | ||
from fl4health.client_managers.base_sampling_manager import SimpleClientManager | ||
from fl4health.clients.basic_client import BasicClient | ||
from fl4health.server.base_server import FlServer | ||
from fl4health.utils.load_data import load_mnist_data | ||
from flwr.common.parameter import ndarrays_to_parameters | ||
from flwr.common.typing import Config, Metrics, Parameters | ||
from flwr.server.strategy import FedAvg | ||
from torch.nn.modules.loss import _Loss | ||
from torch.optim import Optimizer | ||
from torch.utils.data import DataLoader | ||
|
||
from florist.tests.utils.api.models import MnistNet | ||
|
||
|
||
class MnistClient(BasicClient): | ||
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: | ||
train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size=config["batch_size"]) | ||
return train_loader, val_loader | ||
|
||
def get_model(self, config: Config) -> nn.Module: | ||
return MnistNet() | ||
|
||
def get_optimizer(self, config: Config) -> Optimizer: | ||
opt = torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9) | ||
return opt | ||
|
||
def get_criterion(self, config: Config) -> _Loss: | ||
return torch.nn.CrossEntropyLoss() | ||
|
||
|
||
def metric_aggregation( | ||
all_client_metrics: List[Tuple[int, Metrics]], | ||
) -> Tuple[int, Metrics]: | ||
aggregated_metrics: Metrics = {} | ||
total_examples = 0 | ||
# Run through all of the metrics | ||
for num_examples_on_client, client_metrics in all_client_metrics: | ||
total_examples += num_examples_on_client | ||
for metric_name, metric_value in client_metrics.items(): | ||
# Here we assume each metric is normalized by the number of examples on the client. So we scale up to | ||
# get the "raw" value | ||
if isinstance(metric_value, float): | ||
current_metric_value = aggregated_metrics.get(metric_name, 0.0) | ||
assert isinstance(current_metric_value, float) | ||
aggregated_metrics[metric_name] = current_metric_value + num_examples_on_client * metric_value | ||
elif isinstance(metric_value, int): | ||
current_metric_value = aggregated_metrics.get(metric_name, 0) | ||
assert isinstance(current_metric_value, int) | ||
aggregated_metrics[metric_name] = current_metric_value + num_examples_on_client * metric_value | ||
else: | ||
raise ValueError("Metric type is not supported") | ||
return total_examples, aggregated_metrics | ||
|
||
|
||
def normalize_metrics(total_examples: int, aggregated_metrics: Metrics) -> Metrics: | ||
# Normalize all metric values by the total count of examples seen. | ||
normalized_metrics: Metrics = {} | ||
for metric_name, metric_value in aggregated_metrics.items(): | ||
if isinstance(metric_value, float) or isinstance(metric_value, int): | ||
normalized_metrics[metric_name] = metric_value / total_examples | ||
return normalized_metrics | ||
|
||
|
||
def fit_metrics_aggregation_fn( | ||
all_client_metrics: List[Tuple[int, Metrics]], | ||
) -> Metrics: | ||
# This function is run by the server to aggregate metrics returned by each clients fit function | ||
# NOTE: The first value of the tuple is number of examples for FedAvg | ||
total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) | ||
return normalize_metrics(total_examples, aggregated_metrics) | ||
|
||
|
||
def evaluate_metrics_aggregation_fn( | ||
all_client_metrics: List[Tuple[int, Metrics]], | ||
) -> Metrics: | ||
# This function is run by the server to aggregate metrics returned by each clients evaluate function | ||
# NOTE: The first value of the tuple is number of examples for FedAvg | ||
total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) | ||
return normalize_metrics(total_examples, aggregated_metrics) | ||
|
||
|
||
def get_initial_model_parameters(model: nn.Module) -> Parameters: | ||
# Initializing the model parameters on the server side. | ||
return ndarrays_to_parameters([val.cpu().numpy() for _, val in model.state_dict().items()]) | ||
|
||
|
||
def get_fedavg_strategy(model: nn.Module, n_clients: int, fit_config_fn: Callable) -> FedAvg: | ||
strategy = FedAvg( | ||
min_fit_clients=n_clients, | ||
min_evaluate_clients=n_clients, | ||
min_available_clients=n_clients, | ||
on_fit_config_fn=fit_config_fn, | ||
on_evaluate_config_fn=fit_config_fn, | ||
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, | ||
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, | ||
initial_parameters=get_initial_model_parameters(model), | ||
) | ||
return strategy | ||
|
||
|
||
def get_server_fedavg(model: nn.Module, n_clients: int, fit_config_fn: Callable) -> FlServer: | ||
strategy = get_fedavg_strategy(model, n_clients, fit_config_fn) | ||
client_manager = SimpleClientManager() | ||
server = FlServer(strategy=strategy, client_manager=client_manager) | ||
return server |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class MnistNet(nn.Module): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
self.conv1 = nn.Conv2d(1, 8, 5) | ||
self.pool = nn.MaxPool2d(2, 2) | ||
self.conv2 = nn.Conv2d(8, 16, 5) | ||
self.fc1 = nn.Linear(16 * 4 * 4, 120) | ||
self.fc2 = nn.Linear(120, 10) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x = self.pool(F.relu(self.conv1(x))) | ||
x = self.pool(F.relu(self.conv2(x))) | ||
x = x.view(-1, 16 * 4 * 4) | ||
x = F.relu(self.fc1(x)) | ||
x = F.relu(self.fc2(x)) | ||
return x |
Oops, something went wrong.