Commit

Merge pull request #251 from VectorInstitute/dbe/add_feddgga_compatibility

Adaptive and Fed DG-GA PFL Experimentation
emersodb authored Nov 11, 2024
2 parents 657c400 + 93988c7 commit 9f6ac72
Showing 75 changed files with 4,366 additions and 46 deletions.
@@ -42,8 +42,8 @@ def __init__(
criteria. If none, then no server-side checkpointing is performed.
Multiple checkpointers can also be passed in a sequence to checkpoint
based on multiple criteria. Defaults to None.
reporters (Sequence[BaseReporter], optional): A sequence of FL4Health
reporters which the server should send data to before and after each round.
reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should
send data to before and after each round.
"""
assert isinstance(
strategy, FedAvgWithAdaptiveConstraint
9 changes: 4 additions & 5 deletions fl4health/server/tabular_feature_alignment_server.py
@@ -35,15 +35,14 @@ class TabularFeatureAlignmentServer(FlServer):
strategy (Optional[Strategy], optional): The aggregation strategy to be used by the server to handle
client updates and other information potentially sent by the participating clients. If None the
strategy is FedAvg as set by the flwr Server.
wandb_reporter (Optional[ServerWandBReporter], optional): To be provided if the server is to log
information and results to a Weights and Biases account. If None is provided, no logging occurs.
Defaults to None.
checkpointer (Optional[TorchCheckpointer], optional): To be provided if the server should perform
server side checkpointing based on some criteria. If none, then no server-side checkpointing is
performed. Defaults to None.
tab_features_source_of_truth (Optional[TabularFeaturesInfoEncoder]): The information that is required
for aligning client features. If it is not specified, then the server will randomly poll a client
and gather this information from its data source.
for aligning client features. If it is not specified, then the server will randomly poll a client
and gather this information from its data source.
reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should
send data to before and after each round.
"""

def __init__(
17 changes: 11 additions & 6 deletions fl4health/utils/partitioners.py
@@ -148,7 +148,9 @@ def partition_label_indices(
# Dropping the last partition as they are "excess" indices
return partitioned_indices[:-1], min_samples, partition_allocations

def partition_dataset(self, original_dataset: D, max_retries: int = 5) -> Tuple[List[D], Dict[T, np.ndarray]]:
def partition_dataset(
self, original_dataset: D, max_retries: Optional[int] = 5
) -> Tuple[List[D], Dict[T, np.ndarray]]:
"""
Attempts partitioning of the original dataset up to max_retries times. Retries are potentially required if
the user requests a minimum number of labels be assigned to each of the partitions. If the drawn Dirichlet
@@ -157,16 +159,19 @@ def partition_dataset(self, original_dataset: D, max_retries: int = 5) -> Tuple[
Args:
original_dataset (D): The dataset to be partitioned
max_retries (int, optional): Number of times to attempt to satisfy a user provided minimum
label-associated data points per partition. Defaults to 5.
max_retries (Optional[int], optional): Number of times to attempt to satisfy a user provided minimum
label-associated data points per partition. Set this value to None if you want to retry indefinitely.
Defaults to 5.
Raises:
ValueError: Throws this error if the retries have been exhausted and the user provided minimum is not met.
Returns:
List[D]: The partitioned datasets, length should correspond to self.number_of_partitions
Dict[T, np.ndarray]: The Dirichlet distribution used to partition the data points for each label.
Tuple[List[D], Dict[T, np.ndarray]]: List[D] is the partitioned datasets, length should correspond to
self.number_of_partitions. Dict[T, np.ndarray] is the Dirichlet distribution used to partition the data
points for each label.
"""

targets = original_dataset.targets
assert targets is not None, "A label-based partitioner requires targets but this dataset has no targets"
partitioned_indices = [torch.Tensor([]).int() for _ in range(self.number_of_partitions)]
@@ -195,7 +200,7 @@ def partition_dataset(self, original_dataset: D, max_retries: int = 5) -> Tuple[
f"minimum requested was {self.min_label_examples}. Resampling the partition..."
),
)
if partition_attempts == max_retries:
if max_retries is not None and partition_attempts >= max_retries:
raise ValueError(
(
f"Max Retries: {max_retries} reached. Partitioning failed to "
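As a rough sketch (the helper names here are hypothetical, not part of this PR), the retry semantics this change introduces for partition_dataset look like the following: with max_retries=None the partitioner retries indefinitely, otherwise it raises once the attempt budget is exhausted, mirroring the check above.

from typing import Optional

def partition_with_retries(partitioner, dataset, max_retries: Optional[int] = 5):
    # Hypothetical wrapper illustrating the retry behavior of partition_dataset: None means retry forever.
    partition_attempts = 0
    while True:
        # attempt_partition is an assumed helper returning candidate partitions and whether the
        # user-provided minimum number of label-associated examples per partition was satisfied.
        partitions, minimum_satisfied = partitioner.attempt_partition(dataset)
        if minimum_satisfied:
            return partitions
        partition_attempts += 1
        if max_retries is not None and partition_attempts >= max_retries:
            raise ValueError(f"Max Retries: {max_retries} reached. Partitioning failed.")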
36 changes: 30 additions & 6 deletions fl4health/utils/random.py
@@ -8,32 +8,56 @@
from flwr.common.logger import log


def set_all_random_seeds(seed: Optional[int] = 42) -> None:
"""Set seeds for python random, numpy random, and pytorch random.
def set_all_random_seeds(
seed: Optional[int] = 42, use_deterministic_torch_algos: bool = False, disable_torch_benchmarking: bool = False
) -> None:
"""
Set seeds for python random, numpy random, and pytorch random. It also offers the option to force pytorch to use
deterministic algorithms for certain methods and layers (see
https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html for more details). Finally, it
allows one to disable cuda benchmarking, which can also affect the determinism of pytorch training outside of
random seeding. For more information on reproducibility in pytorch see:
https://pytorch.org/docs/stable/notes/randomness.html
Will no-op if seed is `None`.
NOTE: If the use_deterministic_torch_algos flag is set to True, you may need to set the environment variable
CUBLAS_WORKSPACE_CONFIG to something like :4096:8, to avoid CUDA errors. Additional documentation may be found
here: https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
Args:
seed (int): The seed value to be used for random number generators. Default is 42.
seed (Optional[int], optional): The seed value to be used for random number generators. Default is 42. Seed
setting will no-op if the seed is explicitly set to None.
use_deterministic_torch_algos (bool, optional): Whether or not to set torch.use_deterministic_algorithms to
True. Defaults to False.
disable_torch_benchmarking (bool, optional): Whether to explicitly disable cuda benchmarking in
torch processes. Defaults to False.
"""
if seed is None:
log(INFO, "No seed provided. Using random seed.")
else:
log(INFO, f"Setting seed to {seed}")
log(INFO, f"Setting random seeds to {seed}.")
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if use_deterministic_torch_algos:
log(INFO, "Setting torch.use_deterministic_algorithms to True.")
# warn_only is set to true so that layers and components without deterministic algorithms available will
# warn the user that they don't exist, but won't take down the process with an exception.
torch.use_deterministic_algorithms(True, warn_only=True)
if disable_torch_benchmarking:
log(INFO, "Disabling CUDA algorithm benchmarking.")
torch.backends.cudnn.benchmark = False


def unset_all_random_seeds() -> None:
"""
Set random seeds for Python random, NumPy, and PyTorch to None. Running this function would undo
the effects of set_all_random_seeds.
"""
log(INFO, "Setting all random seeds to None.")
log(INFO, "Setting all random seeds to None. Reverting torch determinism settings")
random.seed(None)
np.random.seed(None)
torch.seed()
torch.use_deterministic_algorithms(False)


def generate_hash(length: int = 8) -> str:
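For reference, a small usage sketch exercising the new flags in the same way the clients added in this PR do. The CUBLAS_WORKSPACE_CONFIG value is the one suggested in the docstring above and may be required when deterministic algorithms are enabled on CUDA.

import os

from fl4health.utils.random import set_all_random_seeds, unset_all_random_seeds

# May be required for some CUDA ops when torch.use_deterministic_algorithms(True) is requested.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

set_all_random_seeds(seed=42, use_deterministic_torch_algos=True, disable_torch_benchmarking=True)
# ... training happens here ...
unset_all_random_seeds()  # reverts seeding and the torch determinism settings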
Empty file added research/cifar10/__init__.py
Empty file.
Empty file.
Empty file.
168 changes: 168 additions & 0 deletions research/cifar10/adaptive_pfl/ditto/client.py
@@ -0,0 +1,168 @@
import argparse
import os
from logging import INFO
from pathlib import Path
from typing import Dict, Optional, Sequence, Tuple

import flwr as fl
import torch
import torch.nn as nn
from flwr.common.logger import log
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer
from fl4health.checkpointing.client_module import ClientCheckpointModule
from fl4health.clients.ditto_client import DittoClient
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.losses import LossMeterType
from fl4health.utils.metrics import F1, Accuracy, Metric
from fl4health.utils.random import set_all_random_seeds
from research.cifar10.model import ConvNet
from research.cifar10.preprocess import get_preprocessed_data


class CifarDittoClient(DittoClient):
def __init__(
self,
data_path: Path,
metrics: Sequence[Metric],
device: torch.device,
client_number: int,
learning_rate: float,
heterogeneity_level: float,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
checkpointer: Optional[ClientCheckpointModule] = None,
) -> None:
super().__init__(
data_path=data_path,
metrics=metrics,
device=device,
loss_meter_type=loss_meter_type,
checkpointer=checkpointer,
)
self.client_number = client_number
self.heterogeneity_level = heterogeneity_level
self.learning_rate: float = learning_rate

log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}")

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = narrow_dict_type(config, "batch_size", int)
train_loader, val_loader, _ = get_preprocessed_data(
self.data_path, self.client_number, batch_size, self.heterogeneity_level
)
return train_loader, val_loader

def get_criterion(self, config: Config) -> _Loss:
return torch.nn.CrossEntropyLoss()

def get_optimizer(self, config: Config) -> Dict[str, Optimizer]:
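# Ditto trains a global model (aggregated across clients by the server) alongside a local, personalized
# model, so an optimizer is created for each and returned under the "global" and "local" keys.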
global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=self.learning_rate)
local_optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate)
return {"global": global_optimizer, "local": local_optimizer}

def get_model(self, config: Config) -> nn.Module:
return ConvNet(in_channels=3, use_bn=False, dropout=0.1, hidden=512).to(self.device)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="FL Client Main")
parser.add_argument(
"--artifact_dir",
action="store",
type=str,
help="Path to save client artifacts such as logs and model checkpoints",
required=True,
)
parser.add_argument(
"--dataset_dir",
action="store",
type=str,
help="Path to the preprocessed Cifar 10 Dataset",
required=True,
)
parser.add_argument(
"--run_name",
action="store",
help="Name of the run, model checkpoints will be saved under a subfolder with this name",
required=True,
)
parser.add_argument(
"--server_address",
action="store",
type=str,
help="Server Address for the clients to communicate with the server through",
default="0.0.0.0:8080",
)
parser.add_argument(
"--client_number",
action="store",
type=int,
help="Number of the client for dataset loading",
required=True,
)
parser.add_argument(
"--learning_rate", action="store", type=float, help="Learning rate for local optimization", default=0.1
)
parser.add_argument(
"--seed",
action="store",
type=int,
help="Seed for the random number generators across python, torch, and numpy",
required=False,
)
parser.add_argument(
"--beta",
action="store",
type=float,
help="Heterogeneity level for the dataset",
required=True,
)
args = parser.parse_args()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log(INFO, f"Device to be used: {DEVICE}")
log(INFO, f"Server Address: {args.server_address}")
log(INFO, f"Learning Rate: {args.learning_rate}")
log(INFO, f"Beta: {args.beta}")

# Set the random seed for reproducibility
set_all_random_seeds(args.seed, use_deterministic_torch_algos=True, disable_torch_benchmarking=True)

# Adding extensive checkpointing for the client
checkpoint_dir = os.path.join(args.artifact_dir, args.run_name)
pre_aggregation_best_checkpoint_name = f"pre_aggregation_client_{args.client_number}_best_model.pkl"
pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl"
post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl"
post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl"
checkpointer = ClientCheckpointModule(
pre_aggregation=[
BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
],
post_aggregation=[
BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
],
)

data_path = Path(args.dataset_dir)
client = CifarDittoClient(
data_path=data_path,
metrics=[
Accuracy("accuracy"),
F1("f1_score_macro", average="macro"),
F1("f1_score_weight", average="weighted"),
],
device=DEVICE,
client_number=args.client_number,
learning_rate=args.learning_rate,
heterogeneity_level=args.beta,
checkpointer=checkpointer,
)

fl.client.start_client(server_address=args.server_address, client=client.to_client())
client.shutdown()
7 changes: 7 additions & 0 deletions research/cifar10/adaptive_pfl/ditto/config.yaml
@@ -0,0 +1,7 @@
# Parameters that describe server
n_server_rounds: 20 # The number of rounds to run FL

# Parameters that describe clients
n_clients: 7 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client
batch_size: 32 # The batch size for client training
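As a rough illustration (not part of this change), the config can be loaded with a standard YAML parser; batch_size is the key the client above narrows via narrow_dict_type, and n_clients is the number of clients in the experiment.

import yaml

# Load the experiment config added in this PR and sanity check the keys the client and server rely on.
with open("research/cifar10/adaptive_pfl/ditto/config.yaml") as f:
    config = yaml.safe_load(f)

assert set(config) == {"n_server_rounds", "n_clients", "local_epochs", "batch_size"}
print(config["batch_size"])  # 32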
