generated from VectorInstitute/aieng-template
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #251 from VectorInstitute/dbe/add_feddgga_compatibility
Adaptive and Fed DG-GA PFL Experimentation
- Loading branch information
Showing
75 changed files
with
4,366 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
import argparse | ||
import os | ||
from logging import INFO | ||
from pathlib import Path | ||
from typing import Dict, Optional, Sequence, Tuple | ||
|
||
import flwr as fl | ||
import torch | ||
import torch.nn as nn | ||
from flwr.common.logger import log | ||
from flwr.common.typing import Config | ||
from torch.nn.modules.loss import _Loss | ||
from torch.optim import Optimizer | ||
from torch.utils.data import DataLoader | ||
|
||
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer | ||
from fl4health.checkpointing.client_module import ClientCheckpointModule | ||
from fl4health.clients.ditto_client import DittoClient | ||
from fl4health.utils.config import narrow_dict_type | ||
from fl4health.utils.losses import LossMeterType | ||
from fl4health.utils.metrics import F1, Accuracy, Metric | ||
from fl4health.utils.random import set_all_random_seeds | ||
from research.cifar10.model import ConvNet | ||
from research.cifar10.preprocess import get_preprocessed_data | ||
|
||
|
||
class CifarDittoClient(DittoClient):
    """Ditto client used in the CIFAR-10 experiments.

    Defines the data loaders, loss, optimizers and model for this client. The
    data partition is selected through ``client_number`` and
    ``heterogeneity_level``, which are forwarded to ``get_preprocessed_data``.
    """

    def __init__(
        self,
        data_path: Path,
        metrics: Sequence[Metric],
        device: torch.device,
        client_number: int,
        learning_rate: float,
        heterogeneity_level: float,
        loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
        checkpointer: Optional[ClientCheckpointModule] = None,
    ) -> None:
        """Initialize the client and record its experiment-specific settings.

        Args:
            data_path: Location of the dataset, forwarded to the base class and
                later to ``get_preprocessed_data``.
            metrics: Metrics forwarded to the base class.
            device: Device the model is moved to in ``get_model``.
            client_number: Number of this client, used for dataset loading.
            learning_rate: Learning rate shared by both AdamW optimizers.
            heterogeneity_level: Heterogeneity level passed to
                ``get_preprocessed_data``.
            loss_meter_type: Loss meter type forwarded to the base class.
            checkpointer: Optional checkpoint module forwarded to the base class.
        """
        # The base class is initialized first; the log line below reads
        # ``self.client_name``, which this class never sets itself.
        super().__init__(
            data_path=data_path,
            metrics=metrics,
            device=device,
            loss_meter_type=loss_meter_type,
            checkpointer=checkpointer,
        )
        self.learning_rate: float = learning_rate
        self.heterogeneity_level = heterogeneity_level
        self.client_number = client_number

        identity_message = f"Client Name: {self.client_name}, Client Number: {self.client_number}"
        log(INFO, identity_message)

    def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
        """Return the (train, validation) loaders for this client's partition.

        The batch size is read from the ``"batch_size"`` entry of ``config``.
        """
        size_of_batch = narrow_dict_type(config, "batch_size", int)
        # get_preprocessed_data returns three values; the third is not used here.
        training_loader, validation_loader, _ignored = get_preprocessed_data(
            self.data_path,
            self.client_number,
            size_of_batch,
            self.heterogeneity_level,
        )
        return training_loader, validation_loader

    def get_criterion(self, config: Config) -> _Loss:
        """Return a cross-entropy loss; ``config`` is not consulted."""
        return nn.CrossEntropyLoss()

    def get_optimizer(self, config: Config) -> Dict[str, Optimizer]:
        """Return AdamW optimizers keyed by ``"global"`` and ``"local"``.

        ``"global"`` optimizes ``self.global_model`` and ``"local"`` optimizes
        ``self.model``; both use ``self.learning_rate``.
        """
        step_size = self.learning_rate
        return {
            "global": torch.optim.AdamW(self.global_model.parameters(), lr=step_size),
            "local": torch.optim.AdamW(self.model.parameters(), lr=step_size),
        }

    def get_model(self, config: Config) -> nn.Module:
        """Build the ``ConvNet`` used by this client and move it to ``self.device``."""
        network = ConvNet(in_channels=3, use_bn=False, dropout=0.1, hidden=512)
        return network.to(self.device)
|
||
|
||
if __name__ == "__main__":
    # Entry point: parse the command line, configure checkpointing, build a
    # CifarDittoClient and connect it to the FL server.
    parser = argparse.ArgumentParser(description="FL Client Main")
    parser.add_argument(
        "--artifact_dir",
        action="store",
        type=str,
        help="Path to save client artifacts such as logs and model checkpoints",
        required=True,
    )
    parser.add_argument(
        "--dataset_dir",
        action="store",
        type=str,
        help="Path to the preprocessed Cifar 10 Dataset",
        required=True,
    )
    parser.add_argument(
        "--run_name",
        action="store",
        type=str,
        help="Name of the run, model checkpoints will be saved under a subfolder with this name",
        required=True,
    )
    parser.add_argument(
        "--server_address",
        action="store",
        type=str,
        help="Server Address for the clients to communicate with the server through",
        default="0.0.0.0:8080",
    )
    parser.add_argument(
        "--client_number",
        action="store",
        type=int,
        help="Number of the client for dataset loading",
        required=True,
    )
    parser.add_argument(
        "--learning_rate", action="store", type=float, help="Learning rate for local optimization", default=0.1
    )
    # --seed is optional and has no default, so args.seed is None when omitted;
    # that value is passed through to set_all_random_seeds unchanged.
    parser.add_argument(
        "--seed",
        action="store",
        type=int,
        help="Seed for the random number generators across python, torch, and numpy",
        required=False,
    )
    parser.add_argument(
        "--beta",
        action="store",
        type=float,
        help="Heterogeneity level for the dataset",
        required=True,
    )
    args = parser.parse_args()

    # Use the GPU when one is available, otherwise fall back to the CPU.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log(INFO, f"Device to be used: {DEVICE}")
    log(INFO, f"Server Address: {args.server_address}")
    log(INFO, f"Learning Rate: {args.learning_rate}")
    log(INFO, f"Beta: {args.beta}")

    # Set the random seed for reproducibility
    set_all_random_seeds(args.seed, use_deterministic_torch_algos=True, disable_torch_benchmarking=True)

    # Adding extensive checkpointing for the client: a best-loss and a latest
    # checkpointer are registered for both the pre- and post-aggregation slots,
    # all writing under <artifact_dir>/<run_name> with per-client file names.
    checkpoint_dir = os.path.join(args.artifact_dir, args.run_name)
    pre_aggregation_best_checkpoint_name = f"pre_aggregation_client_{args.client_number}_best_model.pkl"
    pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl"
    post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl"
    post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl"
    checkpointer = ClientCheckpointModule(
        pre_aggregation=[
            BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
            LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
        ],
        post_aggregation=[
            BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
            LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
        ],
    )

    data_path = Path(args.dataset_dir)
    client = CifarDittoClient(
        data_path=data_path,
        metrics=[
            Accuracy("accuracy"),
            F1("f1_score_macro", average="macro"),
            F1("f1_score_weight", average="weighted"),
        ],
        device=DEVICE,
        client_number=args.client_number,
        learning_rate=args.learning_rate,
        heterogeneity_level=args.beta,
        checkpointer=checkpointer,
    )

    # Run shutdown() in a finally clause so it is not skipped when
    # start_client raises (e.g. the server is unreachable or training fails).
    try:
        fl.client.start_client(server_address=args.server_address, client=client.to_client())
    finally:
        client.shutdown()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Parameters that describe the server
n_server_rounds: 20 # The number of rounds to run FL

# Parameters that describe the clients
n_clients: 7 # The number of clients in the FL experiment
local_epochs: 1 # The number of local training epochs each client completes
batch_size: 32 # The batch size for client training
Oops, something went wrong.