Consolidating model and state checkpointing on the client and server sides. #298

Merged · 22 commits · Jan 8, 2025

Commits
4c15c93
WIP
emersodb Nov 21, 2024
5f04c82
Merge branch 'dbe/server_stores_config' into dbe/some_server_side_che…
emersodb Nov 22, 2024
7cba938
WIP checkin to preserve work, not running checks
emersodb Nov 25, 2024
ef2e1a7
Merge branch 'dbe/server_stores_config' into dbe/some_server_side_che…
emersodb Nov 25, 2024
12ff1bb
Committing an initial full migration to the state checkpointing modules
emersodb Nov 26, 2024
eeaf364
Fixing precommit issues/mypy stuff.
emersodb Nov 27, 2024
f93c03b
Fixing issues with the unit tests
emersodb Nov 27, 2024
fd6b34e
Fixing some small smoke test issues.
emersodb Nov 27, 2024
a371595
Merge branch 'dbe/server_stores_config' into dbe/some_server_side_che…
emersodb Nov 27, 2024
aae4faf
Merge branch 'dbe/server_stores_config' into dbe/some_server_side_che…
emersodb Nov 27, 2024
bf3a7bf
Merge branch 'dbe/server_stores_config' into dbe/some_server_side_che…
emersodb Dec 2, 2024
2a4c2a2
Fixing small issue introduced in merge
emersodb Dec 3, 2024
b2227e4
Merge branch 'main' into dbe/some_server_side_checkpointer_consolidation
emersodb Jan 6, 2025
9732abd
Renaming the type aliases to be better type representations
emersodb Jan 6, 2025
d198452
Merge remote-tracking branch 'origin/dbe/fixing_jinja_vulnerability' …
emersodb Jan 6, 2025
0094efb
Merge branch 'main' into dbe/some_server_side_checkpointer_consolidation
emersodb Jan 6, 2025
5cf2df6
Avoiding some code duplication
emersodb Jan 6, 2025
a25aae4
A few small updates from John J.'s PR suggestions.
emersodb Jan 7, 2025
a4ad4d4
Merge branch 'main' into dbe/some_server_side_checkpointer_consolidation
emersodb Jan 7, 2025
e6ab253
Grammatical correction to documentation
emersodb Jan 8, 2025
7066e70
Changing ubuntu to latest to see if its fixed
emersodb Jan 8, 2025
17ef402
Still a problem
emersodb Jan 8, 2025
Changes from all commits
12 changes: 7 additions & 5 deletions examples/ae_examples/cvae_dim_example/server.py
@@ -8,7 +8,8 @@
from flwr.server.strategy import FedAvg

from examples.models.mnist_model import MnistNet
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer
from fl4health.utils.config import load_config
@@ -47,7 +48,10 @@ def main(config: Dict[str, Any]) -> None:
model = MnistNet(int(config["latent_dim"]) * 2)
# To facilitate checkpointing
parameter_exchanger = FullParameterExchanger()
checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], "best_model.pkl")
checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl")
checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer
)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
@@ -66,10 +70,8 @@ def main(config: Dict[str, Any]) -> None:
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
checkpointer=checkpointer,
checkpoint_and_state_module=checkpoint_and_state_module,
)

fl.server.start_server(
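The same server-side migration repeats across the example servers below: the separate model, parameter_exchanger, and checkpointer arguments to FlServer are folded into a single checkpoint_and_state_module. The following is a minimal sketch of the consolidated pattern, using only classes and keyword arguments that appear in these diffs; the model constructor argument, checkpoint path, and fl_config value are illustrative placeholders rather than values from the repository.

from flwr.server.client_manager import SimpleClientManager
from flwr.server.strategy import FedAvg

from examples.models.mnist_model import MnistNet
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer

# Model and parameter exchanger used by the server for checkpointing.
model = MnistNet(10)  # placeholder constructor argument
parameter_exchanger = FullParameterExchanger()

# One or more model checkpointers; here only the best-loss model is kept.
checkpointer = BestLossTorchModuleCheckpointer("checkpoints/", "best_model.pkl")

# The new module bundles the model, exchanger, and checkpointers together.
checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
    model=model,
    parameter_exchanger=parameter_exchanger,
    model_checkpointers=checkpointer,
)

# FlServer now takes the single module instead of three separate arguments.
server = FlServer(
    client_manager=SimpleClientManager(),
    fl_config={},  # placeholder; the examples load this from a config file
    strategy=FedAvg(),
    checkpoint_and_state_module=checkpoint_and_state_module,
)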
12 changes: 7 additions & 5 deletions examples/ae_examples/cvae_examples/conv_cvae_example/server.py
@@ -8,7 +8,8 @@
from flwr.server.strategy import FedAvg

from examples.ae_examples.cvae_examples.conv_cvae_example.models import ConvConditionalDecoder, ConvConditionalEncoder
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.model_bases.autoencoders_base import ConditionalVae
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer
@@ -48,7 +49,10 @@ def main(config: Dict[str, Any]) -> None:

# To facilitate checkpointing
parameter_exchanger = FullParameterExchanger()
checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], model_checkpoint_name)
checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], model_checkpoint_name)
checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer
)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
@@ -67,10 +71,8 @@ def main(config: Dict[str, Any]) -> None:
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
checkpointer=checkpointer,
checkpoint_and_state_module=checkpoint_and_state_module,
)

fl.server.start_server(
12 changes: 7 additions & 5 deletions examples/ae_examples/cvae_examples/mlp_cvae_example/server.py
@@ -8,7 +8,8 @@
from flwr.server.strategy import FedAvg

from examples.ae_examples.cvae_examples.mlp_cvae_example.models import MnistConditionalDecoder, MnistConditionalEncoder
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.model_bases.autoencoders_base import ConditionalVae
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer
@@ -48,7 +49,10 @@ def main(config: Dict[str, Any]) -> None:

# To facilitate checkpointing
parameter_exchanger = FullParameterExchanger()
checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], model_checkpoint_name)
checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], model_checkpoint_name)
checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer
)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
@@ -67,10 +71,8 @@ def main(config: Dict[str, Any]) -> None:
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
checkpointer=checkpointer,
checkpoint_and_state_module=checkpoint_and_state_module,
)

fl.server.start_server(
15 changes: 8 additions & 7 deletions examples/ae_examples/fedprox_vae_example/server.py
@@ -7,9 +7,9 @@
from flwr.server.client_manager import SimpleClientManager

from examples.ae_examples.fedprox_vae_example.models import MnistVariationalDecoder, MnistVariationalEncoder
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import AdaptiveConstraintServerCheckpointAndStateModule
from fl4health.model_bases.autoencoders_base import VariationalAe
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer
from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint
from fl4health.utils.config import load_config
@@ -47,8 +47,11 @@ def main(config: Dict[str, Any]) -> None:
model_checkpoint_name = "best_VAE_model.pkl"

# To facilitate checkpointing
parameter_exchanger = FullParameterExchanger()
checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], model_checkpoint_name)
checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], model_checkpoint_name)

checkpoint_and_state_module = AdaptiveConstraintServerCheckpointAndStateModule(
model=model, model_checkpointers=checkpointer
)

# Server performs simple FedAveraging as its server-side optimization strategy and potentially adapts the
# FedProx proximal weight mu
@@ -70,10 +73,8 @@ def main(config: Dict[str, Any]) -> None:
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
checkpointer=checkpointer,
checkpoint_and_state_module=checkpoint_and_state_module,
)

fl.server.start_server(
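Note that the FedProx VAE server above switches to AdaptiveConstraintServerCheckpointAndStateModule, which, per the diff, is built without a parameter exchanger (the FullParameterExchanger import and argument are removed along with it), presumably because the adaptive-constraint module handles the packed parameter exchange itself. A minimal sketch of just that construction; the placeholder model, directory, and filename are illustrative:

import torch.nn as nn

from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import AdaptiveConstraintServerCheckpointAndStateModule

# Placeholder standing in for the VariationalAe assembled earlier in the example.
model = nn.Linear(784, 10)

# Best-loss checkpointer, as in the diff.
checkpointer = BestLossTorchModuleCheckpointer("checkpoints/", "best_VAE_model.pkl")

# Unlike BaseServerCheckpointAndStateModule, no parameter_exchanger is passed here.
checkpoint_and_state_module = AdaptiveConstraintServerCheckpointAndStateModule(
    model=model,
    model_checkpointers=checkpointer,
)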
4 changes: 2 additions & 2 deletions examples/apfl_example/server.py
@@ -43,7 +43,7 @@ def main(config: Dict[str, Any]) -> None:
local_steps=config.get("local_steps"),
)

initial_model = ApflModule(MnistNetWithBnAndFrozen())
model = ApflModule(MnistNetWithBnAndFrozen())

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
@@ -56,7 +56,7 @@ def main(config: Dict[str, Any]) -> None:
on_evaluate_config_fn=fit_config_fn,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
initial_parameters=get_all_model_parameters(initial_model),
initial_parameters=get_all_model_parameters(model),
)

client_manager = SimpleClientManager()
14 changes: 8 additions & 6 deletions examples/basic_example/server.py
@@ -9,7 +9,8 @@

from examples.models.cnn_model import Net
from examples.utils.functions import make_dict_with_epochs_or_steps
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer
from fl4health.utils.config import load_config
@@ -44,9 +45,12 @@ def main(config: Dict[str, Any]) -> None:
# To facilitate checkpointing
parameter_exchanger = FullParameterExchanger()
checkpointers = [
BestLossTorchCheckpointer(config["checkpoint_path"], "best_model.pkl"),
LatestTorchCheckpointer(config["checkpoint_path"], "latest_model.pkl"),
BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl"),
LatestTorchModuleCheckpointer(config["checkpoint_path"], "latest_model.pkl"),
]
checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointers
)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
@@ -65,10 +69,8 @@ def main(config: Dict[str, Any]) -> None:
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
checkpointer=checkpointers,
checkpoint_and_state_module=checkpoint_and_state_module,
)

fl.server.start_server(
24 changes: 15 additions & 9 deletions examples/docker_basic_example/fl_client/client.py
@@ -1,10 +1,14 @@
import argparse
from pathlib import Path
from typing import Sequence
from typing import Sequence, Tuple

import flwr as fl
import torch
import torch.nn as nn
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from examples.models.cnn_model import Net
from fl4health.clients.basic_client import BasicClient
@@ -20,17 +24,19 @@ def __init__(self, data_path: Path, metrics: Sequence[Metric], device: torch.dev
self.model = Net()
self.parameter_exchanger = FullParameterExchanger()

def setup_client(self, config: Config) -> None:
super().setup_client(config)
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = narrow_dict_type(config, "batch_size", int)
train_loader, validation_loader, num_examples = load_cifar10_data(self.data_path, batch_size)
train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size)
return train_loader, val_loader

self.train_loader = train_loader
self.val_loader = validation_loader
self.num_examples = num_examples
def get_criterion(self, config: Config) -> _Loss:
return torch.nn.CrossEntropyLoss()

self.criterion = torch.nn.CrossEntropyLoss()
self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
def get_optimizer(self, config: Config) -> Optimizer:
return torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)

def get_model(self, config: Config) -> nn.Module:
return Net().to(self.device)


if __name__ == "__main__":
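For reference, the refactored docker client assembled from the pieces in the diff above looks roughly like the sketch below. The class name and the import path for load_cifar10_data are not visible in the diff and are assumed here; the hook bodies are copied from the changed lines.

from typing import Tuple

import torch
import torch.nn as nn
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from examples.models.cnn_model import Net
from fl4health.clients.basic_client import BasicClient
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_cifar10_data  # assumed import path


class CifarClient(BasicClient):  # assumed class name
    # BasicClient calls these hooks during setup, replacing the monolithic
    # setup_client override removed in the diff.
    def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
        batch_size = narrow_dict_type(config, "batch_size", int)
        train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size)
        return train_loader, val_loader

    def get_criterion(self, config: Config) -> _Loss:
        return torch.nn.CrossEntropyLoss()

    def get_optimizer(self, config: Config) -> Optimizer:
        return torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)

    def get_model(self, config: Config) -> nn.Module:
        return Net().to(self.device)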
8 changes: 5 additions & 3 deletions examples/dp_fed_examples/instance_level_dp/client.py
@@ -13,7 +13,7 @@
from torch.utils.data import DataLoader

from examples.models.cnn_model import Net
from fl4health.checkpointing.client_module import ClientCheckpointModule
from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
from fl4health.checkpointing.opacus_checkpointer import BestLossOpacusCheckpointer
from fl4health.clients.instance_level_dp_client import InstanceLevelDpClient
from fl4health.utils.config import narrow_dict_type
@@ -48,12 +48,14 @@ def get_criterion(self, config: Config) -> _Loss:
post_aggregation_checkpointer = BestLossOpacusCheckpointer(
checkpoint_dir=checkpoint_dir, checkpoint_name=checkpoint_name
)
checkpointer = ClientCheckpointModule(post_aggregation=post_aggregation_checkpointer)
checkpoint_and_state_module = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer)

# Load model and data
data_path = Path(args.dataset_path)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
client = CifarClient(data_path, [Accuracy("accuracy")], device, checkpointer=checkpointer)
client = CifarClient(
data_path, [Accuracy("accuracy")], device, checkpoint_and_state_module=checkpoint_and_state_module
)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())

client.shutdown()
22 changes: 13 additions & 9 deletions examples/dp_fed_examples/instance_level_dp/server.py
@@ -10,6 +10,7 @@
from examples.models.cnn_model import Net
from examples.utils.functions import make_dict_with_epochs_or_steps
from fl4health.checkpointing.opacus_checkpointer import BestLossOpacusCheckpointer
from fl4health.checkpointing.server_module import OpacusServerCheckpointAndStateModule
from fl4health.client_managers.poisson_sampling_manager import PoissonSamplingClientManager
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.instance_level_dp_server import InstanceLevelDpServer
@@ -67,15 +68,24 @@ def main(config: Dict[str, Any]) -> None:
local_steps=config.get("local_steps"),
)

initial_model = map_model_to_opacus_model(Net())
model = map_model_to_opacus_model(Net())

client_name = "".join(choices(string.ascii_uppercase, k=5))
checkpoint_dir = "examples/dp_fed_examples/instance_level_dp/"
checkpoint_name = f"server_{client_name}_best_model.pkl"
checkpointer = BestLossOpacusCheckpointer(checkpoint_dir=checkpoint_dir, checkpoint_name=checkpoint_name)

checkpoint_and_state_module = OpacusServerCheckpointAndStateModule(
model=model, parameter_exchanger=FullParameterExchanger(), model_checkpointers=checkpointer
)

# ClientManager that performs Poisson type sampling
client_manager = PoissonSamplingClientManager()

# Server performs simple FedAveraging with Instance Level Differential Privacy
# Must be FedAvg sampling to ensure privacy loss is computed correctly
strategy = OpacusBasicFedAvg(
model=initial_model,
model=model,
fraction_fit=config["client_sampling_rate"],
# Server waits for min_available_clients before starting FL rounds
min_available_clients=config["n_clients"],
@@ -86,22 +96,16 @@ def main(config: Dict[str, Any]) -> None:
on_evaluate_config_fn=fit_config_fn,
)

client_name = "".join(choices(string.ascii_uppercase, k=5))
checkpoint_dir = "examples/dp_fed_examples/instance_level_dp/"
checkpoint_name = f"server_{client_name}_best_model.pkl"

server = InstanceLevelDpServer(
client_manager=client_manager,
fl_config=config,
model=initial_model,
checkpointer=BestLossOpacusCheckpointer(checkpoint_dir=checkpoint_dir, checkpoint_name=checkpoint_name),
parameter_exchanger=FullParameterExchanger(),
strategy=strategy,
noise_multiplier=config["noise_multiplier"],
local_epochs=config.get("local_epochs"),
local_steps=config.get("local_steps"),
batch_size=config["batch_size"],
num_server_rounds=config["n_server_rounds"],
checkpoint_and_state_module=checkpoint_and_state_module,
)

fl.server.start_server(
6 changes: 3 additions & 3 deletions examples/fedopt_example/client.py
@@ -13,7 +13,7 @@
from examples.fedopt_example.client_data import LabelEncoder, Vocabulary, construct_dataloaders
from examples.fedopt_example.metrics import CompoundMetric
from examples.models.lstm_model import LSTM
from fl4health.checkpointing.client_module import ClientCheckpointModule
from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
from fl4health.clients.basic_client import BasicClient, TorchInputType
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.losses import LossMeterType
@@ -27,9 +27,9 @@ def __init__(
metrics: Sequence[Metric],
device: torch.device,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
checkpointer: Optional[ClientCheckpointModule] = None,
checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None,
) -> None:
super().__init__(data_path, metrics, device, loss_meter_type, checkpointer)
super().__init__(data_path, metrics, device, loss_meter_type, checkpoint_and_state_module)
self.weight_matrix: torch.Tensor
self.vocabulary: Vocabulary
self.label_encoder: LabelEncoder
12 changes: 7 additions & 5 deletions examples/fedpca_examples/dim_reduction/server.py
@@ -8,7 +8,8 @@
from flwr.server.strategy import FedAvg

from examples.models.mnist_model import MnistNet
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer
from fl4health.utils.config import load_config
@@ -47,7 +48,10 @@ def main(config: Dict[str, Any]) -> None:
parameter_exchanger = FullParameterExchanger()

# To facilitate checkpointing
checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], "best_model.pkl")
checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl")
checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer
)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
@@ -66,10 +70,8 @@ def main(config: Dict[str, Any]) -> None:
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
checkpointer=checkpointer,
checkpoint_and_state_module=checkpoint_and_state_module,
)

fl.server.start_server(
4 changes: 1 addition & 3 deletions examples/fedprox_example/server.py
@@ -81,9 +81,7 @@ def main(config: Dict[str, Any], server_address: str) -> None:
reporters = [wandb_reporter, json_reporter]
else:
reporters = [json_reporter]
server = FedProxServer(
client_manager=client_manager, fl_config=config, strategy=strategy, model=None, reporters=reporters
)
server = FedProxServer(client_manager=client_manager, fl_config=config, strategy=strategy, reporters=reporters)

fl.server.start_server(
server=server,