Launch client server async (#7)
* Adding Material Dashboard bootstrap template

* Adding venv to gitignore

* Bump

* Fixing static code errors

* Making it into a next js app

* Deleting old code

* Removing about link

* cleanup

* Adding information about building and running the project to the readme and contributing files.

* Moving things around to comply with what poetry expects

* done endpoint

* added tests, updated readme

* Adding return type for client connect

* Fixing tests workflow

* Testing updated precommit checks

* Adding inits

* Adding init, fixing mypy error

* Adding docstrings

* Fixing tests path

* Add first pass of functional launcher

* Updated code that redirects output from server and client processes to files. Slightly modify isort config to make it compatible with black. Add documentation.

* Add a bit more documentation, fix typing and pre-commit

* Address Marcelo's PR suggestions, including test restructuring, asserting based on the presence of a string in the log file, the ability to specify the names of the server and client log files, and a few other small changes

* Fix formatting

* Add missing arguments to docstrings, move the launch file out of the root of api, ensure dump happens after shutdown for client and server, make the log-file-related input to the launch function consistent between client and server, and other small changes

* Fix pre-commit formatting issue

* Fix grammar error in docstring

* Remove unwanted metrics files

* update path of test_launch

---------

Co-authored-by: Marcelo Lotif <[email protected]>
jewelltaylor and lotif authored Feb 29, 2024
1 parent 125d70e commit 927405b
Showing 10 changed files with 3,907 additions and 283 deletions.
160 changes: 160 additions & 0 deletions florist/api/launchers/launch.py
@@ -0,0 +1,160 @@
import logging
import sys
import time
from multiprocessing import Process
from typing import Callable, Sequence

import flwr as fl
from fl4health.clients.basic_client import BasicClient
from fl4health.server.base_server import FlServer
from flwr.common.logger import DEFAULT_FORMATTER
from flwr.server import ServerConfig


def redirect_logging_from_console_to_file(log_file_path: str) -> None:
    """
    Redirect loggers that output to the console to the specified file.
    Args:
        log_file_path (str): The path to the file to log to.
    """

    # Define the file handler to log to and set its format
    fh = logging.FileHandler(log_file_path)
    fh.setFormatter(DEFAULT_FORMATTER)

    # Loop through existing loggers to check if they have one or more StreamHandlers.
    # If they do, remove them (to prevent logging to the console) and add the FileHandler.
    for name in logging.root.manager.loggerDict:
        logger = logging.getLogger(name)
        if any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
            logger.handlers = [h for h in logger.handlers if not isinstance(h, logging.StreamHandler)]
            logger.addHandler(fh)


def start_server(
    server_constructor: Callable[..., FlServer],
    server_address: str,
    n_server_rounds: int,
    server_log_file_name: str,
) -> None:
    """
    Start the FL server. Redirects console logging, stdout and stderr to a file.
    Args:
        server_constructor (Callable[..., FlServer]): Callable that constructs the FL server.
        server_address (str): String of <IP>:<PORT> to make the server available at.
        n_server_rounds (int): The number of rounds of FL to perform.
        server_log_file_name (str): The name of the server log file.
    """
    redirect_logging_from_console_to_file(server_log_file_name)
    log_file = open(server_log_file_name, "a")
    # Send remaining output (i.e. print) from stdout and stderr to the file
    sys.stdout = log_file
    sys.stderr = log_file
    server = server_constructor()
    fl.server.start_server(
        server=server,
        server_address=server_address,
        config=ServerConfig(num_rounds=n_server_rounds),
    )
    server.shutdown()
    server.metrics_reporter.dump()
    log_file.close()


def start_client(client: BasicClient, server_address: str, client_log_file_name: str) -> None:
    """
    Start the FL client. Redirects console logging, stdout and stderr to a file.
    Args:
        client (BasicClient): BasicClient instance to launch.
        server_address (str): String of <IP>:<PORT> where the server is available.
        client_log_file_name (str): The name of the client log file.
    """
    redirect_logging_from_console_to_file(client_log_file_name)
    log_file = open(client_log_file_name, "a")
    # Send remaining output (i.e. print) from stdout and stderr to the file
    sys.stdout = log_file
    sys.stderr = log_file
    fl.client.start_numpy_client(server_address=server_address, client=client)
    client.shutdown()
    client.metrics_reporter.dump()
    log_file.close()


def launch_server(
    server_constructor: Callable[..., FlServer],
    server_address: str,
    n_server_rounds: int,
    server_log_file_name: str,
    seconds_to_sleep: int = 10,
) -> Process:
    """
    Spawn a process that starts the FL server.
    Args:
        server_constructor (Callable[..., FlServer]): Callable that constructs the FL server.
        server_address (str): String of <IP>:<PORT> to make the server available at.
        n_server_rounds (int): The number of rounds of FL to perform.
        server_log_file_name (str): The name of the log file for the server.
        seconds_to_sleep (int): The number of seconds to sleep after launching the server process.
    Returns:
        Process: The process running the FL server.
    """
    server_process = Process(
        target=start_server,
        args=(
            server_constructor,
            server_address,
            n_server_rounds,
            server_log_file_name,
        ),
    )
    server_process.start()
    time.sleep(seconds_to_sleep)
    return server_process


def launch_client(client: BasicClient, server_address: str, client_log_file_name: str) -> None:
    """
    Spawn a process that starts an FL client.
    Args:
        client (BasicClient): BasicClient instance to launch.
        server_address (str): String of <IP>:<PORT> where the server is available.
        client_log_file_name (str): The name used for the client log file.
    """
    client_process = Process(target=start_client, args=(client, server_address, client_log_file_name))
    client_process.start()


def launch(
    server_constructor: Callable[..., FlServer],
    server_address: str,
    n_server_rounds: int,
    clients: Sequence[BasicClient],
    server_base_log_file_name: str = "server",
    client_base_log_file_name: str = "client",
) -> None:
    """
    Launch an FL experiment. Launches the server first, then the clients.
    Joins the server process after the clients are launched to block until FL is complete,
    since the server is the last process to finish executing.
    Args:
        server_constructor (Callable[..., FlServer]): Callable that constructs the FL server.
        server_address (str): String of <IP>:<PORT> to make the server available at.
        n_server_rounds (int): The number of rounds of FL to perform.
        clients (Sequence[BasicClient]): List of BasicClient instances to launch.
        server_base_log_file_name (str): The base name used for the server log file.
            ".out" is appended to produce the final server log file name.
        client_base_log_file_name (str): The base name used for the client log files.
            "_{i}.out" is appended to produce the final log file name for client i.
    """
    server_log_file_name = f"{server_base_log_file_name}.out"
    server_process = launch_server(server_constructor, server_address, n_server_rounds, server_log_file_name)
    for i, client in enumerate(clients):
        client_log_file_name = f"{client_base_log_file_name}_{i}.out"
        launch_client(client, server_address, client_log_file_name)
    server_process.join()
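
For context (not part of this commit's diff), here is a minimal, hedged sketch of how the helpers above might be wired together, reusing the MnistClient, get_server_fedavg, and MnistNet test utilities added later in this commit; the data paths, client count, and config values are assumptions:

from functools import partial
from pathlib import Path

import torch

from florist.api.launchers.launch import launch
from florist.tests.utils.api.fl4health_utils import MnistClient, get_server_fedavg
from florist.tests.utils.api.models import MnistNet


def fit_config(current_server_round: int) -> dict:
    # Config sent to clients each round; values here are illustrative.
    return {"batch_size": 8, "local_epochs": 1, "current_server_round": current_server_round}


# Two clients reading MNIST from assumed local data directories.
clients = [MnistClient(Path(f"/tmp/mnist/{i}"), [], torch.device("cpu")) for i in range(2)]

# launch() expects a zero-argument server constructor, so model and config are bound with partial.
server_constructor = partial(get_server_fedavg, model=MnistNet(), n_clients=2, fit_config_fn=fit_config)

# Blocks until the server process finishes; output goes to server.out, client_0.out and client_1.out.
launch(server_constructor, "0.0.0.0:8080", 2, clients, "server", "client")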
File renamed without changes.
64 changes: 64 additions & 0 deletions florist/tests/integration/api/launchers/test_launch.py
@@ -0,0 +1,64 @@
import os
import re
import tempfile
from functools import partial
from pathlib import Path
from typing import Callable, Dict

import torch
from fl4health.server.base_server import FlServer

from florist.api.launchers.launch import launch
from florist.tests.utils.api.fl4health_utils import MnistClient, get_server_fedavg
from florist.tests.utils.api.models import MnistNet


def fit_config(batch_size: int, local_epochs: int, current_server_round: int) -> Dict[str, int]:
    return {
        "batch_size": batch_size,
        "current_server_round": current_server_round,
        "local_epochs": local_epochs,
    }


def get_server(
    fit_config: Callable[..., Dict[str, int]] = fit_config,
    n_clients: int = 2,
    batch_size: int = 8,
    local_epochs: int = 1,
) -> FlServer:
    fit_config_fn = partial(fit_config, batch_size, local_epochs)
    server = get_server_fedavg(model=MnistNet(), n_clients=n_clients, fit_config_fn=fit_config_fn)
    return server


def assert_string_in_file(file_path: str, search_string: str) -> bool:
    with open(file_path, "r") as f:
        file_contents = f.read()
        match = re.search(search_string, file_contents)
        return match is not None


def test_launch() -> None:
    n_clients = 2
    n_server_rounds = 2
    server_address = "0.0.0.0:8080"

    with tempfile.TemporaryDirectory() as temp_dir:
        client_data_paths = [Path(f"{temp_dir}/{i}") for i in range(n_clients)]
        for client_data_path in client_data_paths:
            os.mkdir(client_data_path)
        clients = [MnistClient(client_data_path, [], torch.device("cpu")) for client_data_path in client_data_paths]

        server_path = os.path.join(temp_dir, "server")
        client_base_path = f"{temp_dir}/client"
        launch(
            get_server,
            server_address,
            n_server_rounds,
            clients,
            server_path,
            client_base_path,
        )

        assert assert_string_in_file(f"{server_path}.out", "FL finished in")
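
As a usage note (the exact invocation is an assumption; the commit message only says the project is managed with poetry), this integration test could be run with something like:

poetry run pytest florist/tests/integration/api/launchers/test_launch.py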
Empty file added florist/tests/unit/__init__.py
Empty file.
File renamed without changes.
Empty file.
109 changes: 109 additions & 0 deletions florist/tests/utils/api/fl4health_utils.py
@@ -0,0 +1,109 @@
from typing import Callable, List, Tuple

import torch
import torch.nn as nn
from fl4health.client_managers.base_sampling_manager import SimpleClientManager
from fl4health.clients.basic_client import BasicClient
from fl4health.server.base_server import FlServer
from fl4health.utils.load_data import load_mnist_data
from flwr.common.parameter import ndarrays_to_parameters
from flwr.common.typing import Config, Metrics, Parameters
from flwr.server.strategy import FedAvg
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from florist.tests.utils.api.models import MnistNet


class MnistClient(BasicClient):
    def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
        train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size=config["batch_size"])
        return train_loader, val_loader

    def get_model(self, config: Config) -> nn.Module:
        return MnistNet()

    def get_optimizer(self, config: Config) -> Optimizer:
        opt = torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        return opt

    def get_criterion(self, config: Config) -> _Loss:
        return torch.nn.CrossEntropyLoss()


def metric_aggregation(
    all_client_metrics: List[Tuple[int, Metrics]],
) -> Tuple[int, Metrics]:
    aggregated_metrics: Metrics = {}
    total_examples = 0
    # Run through all of the metrics
    for num_examples_on_client, client_metrics in all_client_metrics:
        total_examples += num_examples_on_client
        for metric_name, metric_value in client_metrics.items():
            # Here we assume each metric is normalized by the number of examples on the client. So we scale up to
            # get the "raw" value
            if isinstance(metric_value, float):
                current_metric_value = aggregated_metrics.get(metric_name, 0.0)
                assert isinstance(current_metric_value, float)
                aggregated_metrics[metric_name] = current_metric_value + num_examples_on_client * metric_value
            elif isinstance(metric_value, int):
                current_metric_value = aggregated_metrics.get(metric_name, 0)
                assert isinstance(current_metric_value, int)
                aggregated_metrics[metric_name] = current_metric_value + num_examples_on_client * metric_value
            else:
                raise ValueError("Metric type is not supported")
    return total_examples, aggregated_metrics


def normalize_metrics(total_examples: int, aggregated_metrics: Metrics) -> Metrics:
    # Normalize all metric values by the total count of examples seen.
    normalized_metrics: Metrics = {}
    for metric_name, metric_value in aggregated_metrics.items():
        if isinstance(metric_value, float) or isinstance(metric_value, int):
            normalized_metrics[metric_name] = metric_value / total_examples
    return normalized_metrics


def fit_metrics_aggregation_fn(
    all_client_metrics: List[Tuple[int, Metrics]],
) -> Metrics:
    # This function is run by the server to aggregate metrics returned by each client's fit function
    # NOTE: The first value of the tuple is the number of examples for FedAvg
    total_examples, aggregated_metrics = metric_aggregation(all_client_metrics)
    return normalize_metrics(total_examples, aggregated_metrics)


def evaluate_metrics_aggregation_fn(
    all_client_metrics: List[Tuple[int, Metrics]],
) -> Metrics:
    # This function is run by the server to aggregate metrics returned by each client's evaluate function
    # NOTE: The first value of the tuple is the number of examples for FedAvg
    total_examples, aggregated_metrics = metric_aggregation(all_client_metrics)
    return normalize_metrics(total_examples, aggregated_metrics)


def get_initial_model_parameters(model: nn.Module) -> Parameters:
    # Initializing the model parameters on the server side.
    return ndarrays_to_parameters([val.cpu().numpy() for _, val in model.state_dict().items()])


def get_fedavg_strategy(model: nn.Module, n_clients: int, fit_config_fn: Callable) -> FedAvg:
    strategy = FedAvg(
        min_fit_clients=n_clients,
        min_evaluate_clients=n_clients,
        min_available_clients=n_clients,
        on_fit_config_fn=fit_config_fn,
        on_evaluate_config_fn=fit_config_fn,
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        initial_parameters=get_initial_model_parameters(model),
    )
    return strategy


def get_server_fedavg(model: nn.Module, n_clients: int, fit_config_fn: Callable) -> FlServer:
    strategy = get_fedavg_strategy(model, n_clients, fit_config_fn)
    client_manager = SimpleClientManager()
    server = FlServer(strategy=strategy, client_manager=client_manager)
    return server
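
To illustrate the weighted aggregation above (an illustrative example, not part of the commit), two clients reporting per-example-normalized accuracies are scaled back up by their example counts and then renormalized by the total:

from florist.tests.utils.api.fl4health_utils import metric_aggregation, normalize_metrics

# Client A averaged 0.8 accuracy over 10 examples, client B averaged 0.6 over 30.
total, aggregated = metric_aggregation([(10, {"accuracy": 0.8}), (30, {"accuracy": 0.6})])
# total == 40 and aggregated == {"accuracy": 10 * 0.8 + 30 * 0.6} == {"accuracy": 26.0}
print(normalize_metrics(total, aggregated))  # {"accuracy": 0.65}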
21 changes: 21 additions & 0 deletions florist/tests/utils/api/models.py
@@ -0,0 +1,21 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class MnistNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(8, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x
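
A quick shape sanity check (illustrative only, not part of the diff): MNIST inputs are 1x28x28, so conv1 (5x5) yields 24x24, pooling yields 12x12, conv2 yields 8x8, and the final pool yields 4x4, which matches the 16 * 4 * 4 flatten size used by fc1:

import torch

from florist.tests.utils.api.models import MnistNet

net = MnistNet()
out = net(torch.randn(1, 1, 28, 28))  # a single fake MNIST image
print(out.shape)  # torch.Size([1, 10]), one score per digit class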
