Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Server storing FL configs and Consolidating Base Server Functionality #294

Merged
merged 15 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/static_code_checks.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ jobs:
virtual-environment: .venv/
# Ignoring vulnerability in cryptography
# Fix is 43.0.1 but flwr 1.9 depends on < 43
# PYSEC-2022-43145 seems like a bug in pip audit, we should probably try to remove the ignore at some point
ignore-vulns: |
GHSA-h4gh-qq45-vh27
GHSA-q34m-jh98-gwm2
GHSA-f9vj-2wh5-fj8j
PYSEC-2022-43145
5 changes: 3 additions & 2 deletions examples/ae_examples/cvae_dim_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from examples.models.mnist_model import MnistNet
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServerWithCheckpointing
from fl4health.servers.base_server import FlServer
from fl4health.utils.config import load_config
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.utils.parameter_extraction import get_all_model_parameters
Expand Down Expand Up @@ -63,8 +63,9 @@ def main(config: Dict[str, Any]) -> None:
initial_parameters=get_all_model_parameters(model),
)

server = FlServerWithCheckpointing(
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.model_bases.autoencoders_base import ConditionalVae
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServerWithCheckpointing
from fl4health.servers.base_server import FlServer
from fl4health.utils.config import load_config
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.utils.parameter_extraction import get_all_model_parameters
Expand Down Expand Up @@ -64,8 +64,9 @@ def main(config: Dict[str, Any]) -> None:
initial_parameters=get_all_model_parameters(model),
)

server = FlServerWithCheckpointing(
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
Expand Down
5 changes: 3 additions & 2 deletions examples/ae_examples/cvae_examples/mlp_cvae_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.model_bases.autoencoders_base import ConditionalVae
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServerWithCheckpointing
from fl4health.servers.base_server import FlServer
from fl4health.utils.config import load_config
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.utils.parameter_extraction import get_all_model_parameters
Expand Down Expand Up @@ -64,8 +64,9 @@ def main(config: Dict[str, Any]) -> None:
initial_parameters=get_all_model_parameters(model),
)

server = FlServerWithCheckpointing(
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
Expand Down
5 changes: 3 additions & 2 deletions examples/ae_examples/fedprox_vae_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.model_bases.autoencoders_base import VariationalAe
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServerWithCheckpointing
from fl4health.servers.base_server import FlServer
from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint
from fl4health.utils.config import load_config
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
Expand Down Expand Up @@ -67,8 +67,9 @@ def main(config: Dict[str, Any]) -> None:
initial_loss_weight=config["initial_proximal_weight"],
)

server = FlServerWithCheckpointing(
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
Expand Down
2 changes: 1 addition & 1 deletion examples/apfl_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def main(config: Dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FlServer(client_manager, strategy, reporters=[JsonReporter()])
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, reporters=[JsonReporter()])

fl.server.start_server(
server=server,
Expand Down
5 changes: 3 additions & 2 deletions examples/basic_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from examples.utils.functions import make_dict_with_epochs_or_steps
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServerWithCheckpointing
from fl4health.servers.base_server import FlServer
from fl4health.utils.config import load_config
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.utils.parameter_extraction import get_all_model_parameters
Expand Down Expand Up @@ -62,8 +62,9 @@ def main(config: Dict[str, Any]) -> None:
initial_parameters=get_all_model_parameters(model),
)

server = FlServerWithCheckpointing(
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
Expand Down
2 changes: 1 addition & 1 deletion examples/ditto_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def main(config: Dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = DittoServer(client_manager=client_manager, strategy=strategy)
server = DittoServer(client_manager=client_manager, fl_config=config, strategy=strategy)

fl.server.start_server(
server=server,
Expand Down
1 change: 1 addition & 0 deletions examples/dp_fed_examples/client_level_dp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def main(config: Dict[str, Any]) -> None:
)
server = ClientLevelDPFedAvgServer(
client_manager=client_manager,
fl_config=config,
strategy=strategy,
server_noise_multiplier=config["server_noise_multiplier"],
num_server_rounds=config["n_server_rounds"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def main(config: Dict[str, Any]) -> None:

server = ClientLevelDPFedAvgServer(
client_manager=client_manager,
fl_config=config,
strategy=strategy,
num_server_rounds=config["n_server_rounds"],
server_noise_multiplier=config["server_noise_multiplier"],
Expand Down
48 changes: 4 additions & 44 deletions examples/dp_fed_examples/instance_level_dp/server.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
import argparse
import string
from collections.abc import Sequence
from functools import partial
from random import choices
from typing import Any, Dict, Optional

import flwr as fl
import torch.nn as nn
from flwr.common.parameter import parameters_to_ndarrays
from flwr.common.typing import Config
from flwr.server.client_manager import ClientManager

from examples.models.cnn_model import Net
from examples.utils.functions import make_dict_with_epochs_or_steps
from fl4health.checkpointing.opacus_checkpointer import BestLossOpacusCheckpointer, OpacusCheckpointer
from fl4health.checkpointing.opacus_checkpointer import BestLossOpacusCheckpointer
from fl4health.client_managers.poisson_sampling_manager import PoissonSamplingClientManager
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.servers.instance_level_dp_server import InstanceLevelDpServer
from fl4health.strategies.basic_fedavg import OpacusBasicFedAvg
from fl4health.utils.config import load_config
Expand Down Expand Up @@ -61,43 +56,6 @@ def fit_config(
)


class CifarInstanceLevelDPServerWithCheckpointing(InstanceLevelDpServer):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All functionality implemented here is now neatly contained in the parent class of InstanceLevelDpServer

def __init__(
self,
client_manager: ClientManager,
model: nn.Module,
noise_multiplier: float,
batch_size: int,
num_server_rounds: int,
strategy: OpacusBasicFedAvg,
local_epochs: Optional[int] = None,
local_steps: Optional[int] = None,
checkpointer: Optional[OpacusCheckpointer] = None,
reporters: Sequence[BaseReporter] | None = None,
delta: Optional[float] = None,
) -> None:
super().__init__(
client_manager,
noise_multiplier,
batch_size,
num_server_rounds,
strategy,
local_epochs,
local_steps,
checkpointer,
reporters,
delta,
)
self.parameter_exchanger = FullParameterExchanger()
self.server_model = model

# Setting up a hydration method so that checkpointing can happen on the server side
def _hydrate_model_for_checkpointing(self) -> nn.Module:
model_ndarrays = parameters_to_ndarrays(self.parameters)
self.parameter_exchanger.pull_parameters(model_ndarrays, self.server_model)
return self.server_model


def main(config: Dict[str, Any]) -> None:
# This function will be used to produce a config that is sent to each client to initialize their own environment
fit_config_fn = partial(
Expand Down Expand Up @@ -132,10 +90,12 @@ def main(config: Dict[str, Any]) -> None:
checkpoint_dir = "examples/dp_fed_examples/instance_level_dp/"
checkpoint_name = f"server_{client_name}_best_model.pkl"

server = CifarInstanceLevelDPServerWithCheckpointing(
server = InstanceLevelDpServer(
client_manager=client_manager,
fl_config=config,
model=initial_model,
checkpointer=BestLossOpacusCheckpointer(checkpoint_dir=checkpoint_dir, checkpoint_name=checkpoint_name),
parameter_exchanger=FullParameterExchanger(),
strategy=strategy,
noise_multiplier=config["noise_multiplier"],
local_epochs=config.get("local_epochs"),
Expand Down
1 change: 1 addition & 0 deletions examples/dp_scaffold_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def main(config: Dict[str, Any]) -> None:
client_manager = PoissonSamplingClientManager()
server = DPScaffoldServer(
client_manager=client_manager,
fl_config=config,
noise_multiplier=config["noise_multiplier"],
batch_size=config["batch_size"],
local_steps=config["local_steps"],
Expand Down
2 changes: 1 addition & 1 deletion examples/dynamic_layer_exchange_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def main(config: Dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FlServer(client_manager, strategy)
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy)

fl.server.start_server(
server=server,
Expand Down
2 changes: 1 addition & 1 deletion examples/fedbn_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def main(config: Dict[str, Any], server_address: str, dataset_name: str) -> None
)

client_manager = SimpleClientManager()
server = FlServer(client_manager, strategy)
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy)

fl.server.start_server(
server=server,
Expand Down
2 changes: 1 addition & 1 deletion examples/feddg_ga_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def main(config: Dict[str, Any]) -> None:
# will return the same sampling until it is told to reset, which in FedDgGaStrategy
# is done right before fit_round.
client_manager = FixedSamplingClientManager()
server = FlServer(strategy=strategy, client_manager=client_manager, reporters=[JsonReporter()])
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, reporters=[JsonReporter()])

fl.server.start_server(
server=server,
Expand Down
2 changes: 1 addition & 1 deletion examples/fedopt_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def main(config: Dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FlServer(client_manager, strategy)
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy)

fl.server.start_server(
server_address=config["server_address"],
Expand Down
5 changes: 3 additions & 2 deletions examples/fedpca_examples/dim_reduction/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from examples.models.mnist_model import MnistNet
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServerWithCheckpointing
from fl4health.servers.base_server import FlServer
from fl4health.utils.config import load_config
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.utils.parameter_extraction import get_all_model_parameters
Expand Down Expand Up @@ -63,8 +63,9 @@ def main(config: Dict[str, Any]) -> None:
initial_parameters=get_all_model_parameters(model),
)

server = FlServerWithCheckpointing(
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
Expand Down
2 changes: 1 addition & 1 deletion examples/fedper_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def main(config: Dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FlServer(client_manager, strategy)
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy)

fl.server.start_server(
server=server,
Expand Down
4 changes: 3 additions & 1 deletion examples/fedpm_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def main(config: Dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FedPmServer(client_manager, strategy, reset_frequency=config["priors_reset_frequency"])
server = FedPmServer(
client_manager, fl_config=config, strategy=strategy, reset_frequency=config["priors_reset_frequency"]
)

fl.server.start_server(
server=server,
Expand Down
4 changes: 3 additions & 1 deletion examples/fedprox_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def main(config: Dict[str, Any], server_address: str) -> None:
reporters = [wandb_reporter, json_reporter]
else:
reporters = [json_reporter]
server = FedProxServer(client_manager=client_manager, strategy=strategy, model=None, reporters=reporters)
server = FedProxServer(
client_manager=client_manager, fl_config=config, strategy=strategy, model=None, reporters=reporters
)

fl.server.start_server(
server=server,
Expand Down
2 changes: 1 addition & 1 deletion examples/fedrep_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def main(config: Dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FlServer(client_manager, strategy)
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy)

fl.server.start_server(
server=server,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.model_bases.fedsimclr_base import FedSimClrModel
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServerWithCheckpointing
from fl4health.servers.base_server import FlServer
from fl4health.utils.config import load_config
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.utils.parameter_extraction import get_all_model_parameters
Expand Down Expand Up @@ -67,8 +67,9 @@ def main(config: Dict[str, Any]) -> None:
initial_parameters=get_all_model_parameters(model),
)

server = FlServerWithCheckpointing(
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.model_bases.fedsimclr_base import FedSimClrModel
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServerWithCheckpointing
from fl4health.servers.base_server import FlServer
from fl4health.utils.config import load_config
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.utils.parameter_extraction import get_all_model_parameters
Expand Down Expand Up @@ -66,8 +66,9 @@ def main(config: Dict[str, Any]) -> None:
initial_parameters=get_all_model_parameters(model),
)

server = FlServerWithCheckpointing(
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
Expand Down
2 changes: 1 addition & 1 deletion examples/fenda_ditto_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def main(config: Dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FlServer(client_manager, strategy)
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy)

fl.server.start_server(
server=server,
Expand Down
2 changes: 1 addition & 1 deletion examples/fenda_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def main(config: Dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FlServer(client_manager, strategy)
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy)

fl.server.start_server(
server=server,
Expand Down
2 changes: 1 addition & 1 deletion examples/fl_plus_local_ft_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def main(config: Dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FlServer(client_manager, strategy)
server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy)

fl.server.start_server(
server_address="0.0.0.0:8080",
Expand Down
Loading