From 5a56679d07cd948af243c67a3ebccb6b5a8cfed2 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Wed, 25 Sep 2024 15:32:26 -0400 Subject: [PATCH 01/19] Small change --- fl4health/strategies/feddg_ga_strategy.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/fl4health/strategies/feddg_ga_strategy.py b/fl4health/strategies/feddg_ga_strategy.py index 3d47b34be..e511314f6 100644 --- a/fl4health/strategies/feddg_ga_strategy.py +++ b/fl4health/strategies/feddg_ga_strategy.py @@ -252,9 +252,19 @@ def aggregate_fit( (Tuple[Optional[Parameters], Dict[str, Scalar]]) A tuple containing the aggregated parameters and the aggregated fit metrics. """ - # The original aggregated parameters is done by the super class (which we want to - # override its behaviour here), so we are discarding it to recalculate them in the lines below - _, metrics_aggregated = super().aggregate_fit(server_round, results, failures) + if not results: + return None, {} + # Do not aggregate if there are failures and failures are not accepted + if not self.accept_failures and failures: + return None, {} + + # Aggregate custom metrics if aggregation fn was provided + metrics_aggregated = {} + if self.fit_metrics_aggregation_fn: + fit_metrics = [(res.num_examples, res.metrics) for _, res in results] + metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics) + elif server_round == 1: # Only log this warning once + log(WARNING, "No fit_metrics_aggregation_fn provided") self.train_metrics = {} for client_proxy, fit_res in results: From ed2d08b302d90eb567fe29470d1088436b243184 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Wed, 2 Oct 2024 13:40:22 -0400 Subject: [PATCH 02/19] Setting up preliminary experimental folder structure and preprocessing scripts --- examples/feddg_ga_example/config.yaml | 3 +- research/cifar10/__init__.py | 0 research/cifar10/adaptive_pfl/__init__.py | 
0 .../data_preprocess_scripts/preprocess.slrm | 72 +++++++ .../data_preprocess_scripts/preprocess_all.sh | 36 ++++ .../cifar10/adaptive_pfl/ditto/__init__.py | 0 .../adaptive_pfl/fenda_ditto/__init__.py | 0 .../cifar10/adaptive_pfl/mrmtl/__init__.py | 0 research/cifar10/fed_dgga_pfl/__init__.py | 0 .../cifar10/fed_dgga_pfl/ditto/__init__.py | 0 .../cifar10/fed_dgga_pfl/fenda/__init__.py | 0 .../fed_dgga_pfl/fenda_ditto/__init__.py | 0 research/cifar10/model.py | 47 +++++ research/cifar10/preprocess.py | 185 ++++++++++++++++++ 14 files changed, 342 insertions(+), 1 deletion(-) create mode 100644 research/cifar10/__init__.py create mode 100644 research/cifar10/adaptive_pfl/__init__.py create mode 100644 research/cifar10/adaptive_pfl/data_preprocess_scripts/preprocess.slrm create mode 100644 research/cifar10/adaptive_pfl/data_preprocess_scripts/preprocess_all.sh create mode 100644 research/cifar10/adaptive_pfl/ditto/__init__.py create mode 100644 research/cifar10/adaptive_pfl/fenda_ditto/__init__.py create mode 100644 research/cifar10/adaptive_pfl/mrmtl/__init__.py create mode 100644 research/cifar10/fed_dgga_pfl/__init__.py create mode 100644 research/cifar10/fed_dgga_pfl/ditto/__init__.py create mode 100644 research/cifar10/fed_dgga_pfl/fenda/__init__.py create mode 100644 research/cifar10/fed_dgga_pfl/fenda_ditto/__init__.py create mode 100644 research/cifar10/model.py create mode 100644 research/cifar10/preprocess.py diff --git a/examples/feddg_ga_example/config.yaml b/examples/feddg_ga_example/config.yaml index 50c7aa8c0..992143f44 100644 --- a/examples/feddg_ga_example/config.yaml +++ b/examples/feddg_ga_example/config.yaml @@ -5,4 +5,5 @@ n_server_rounds: 3 # The number of rounds to run FL n_clients: 2 # The number of clients in the FL experiment local_steps: 5 # The number of local steps (one per batch) to complete for client batch_size: 128 # The batch size for client training -evaluate_after_fit: True # Evaluates model immediately after local training on the 
validation set (in addition to the training set) +# Evaluates model immediately after local training on the validation set (in addition to the training set) +evaluate_after_fit: True diff --git a/research/cifar10/__init__.py b/research/cifar10/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/research/cifar10/adaptive_pfl/__init__.py b/research/cifar10/adaptive_pfl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/research/cifar10/adaptive_pfl/data_preprocess_scripts/preprocess.slrm b/research/cifar10/adaptive_pfl/data_preprocess_scripts/preprocess.slrm new file mode 100644 index 000000000..ca04f463d --- /dev/null +++ b/research/cifar10/adaptive_pfl/data_preprocess_scripts/preprocess.slrm @@ -0,0 +1,72 @@ +#!/bin/bash + +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:0 +#SBATCH --mem=32G +#SBATCH --partition=cpu +#SBATCH --qos=normal +#SBATCH --job-name=cifar_dirichlet_allocation +#SBATCH --output=%j_%x.out +#SBATCH --error=%j_%x.err + +# Note: +# ntasks: Total number of processes to use across world +# ntasks-per-node: How many processes each node should create + +# Set NCCL options +# export NCCL_DEBUG=INFO +# NCCL backend to communicate between GPU workers is not provided in vector's cluster. +# Disable this option in slurm. 
+export NCCL_IB_DISABLE=1 + +if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ + [[ "${SLURM_JOB_PARTITION}" == "rtx6000" ]]; then + echo export NCCL_SOCKET_IFNAME=bond0 on "${SLURM_JOB_PARTITION}" + export NCCL_SOCKET_IFNAME=bond0 +fi + +# Process inputs (VENV_PATH and LOG_DIR are accepted with or without a trailing slash) +VENV_PATH=$1 +DATASET_DIR=$2 +OUTPUT_DIR=$3 +SEED=$4 +BETA=$5 +NUM_PARTITIONS=$6 +LOG_DIR=$7 + +echo "Python Venv Path: ${VENV_PATH}" +echo "CIFAR Dataset Path: ${DATASET_DIR}" +echo "Output for Partitions: ${OUTPUT_DIR}" +echo "Reproducibility Seed: ${SEED}" +echo "Dirichlet Beta: ${BETA}" +echo "Number of partitions to produce: ${NUM_PARTITIONS}" +echo "Logs being placed in: ${LOG_DIR}" + +SERVER_ADDRESS="${SLURMD_NODENAME}:${SERVER_PORT}" + +echo "Server Address: ${SERVER_ADDRESS}" + +LOG_PATH="${LOG_DIR:+${LOG_DIR%/}/}preprocess_${BETA}_${NUM_PARTITIONS}_${SEED}.log" + +echo "World size: ${SLURM_NTASKS}" +echo "Number of nodes: ${SLURM_NNODES}" +NUM_GPUs=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) +echo "GPUs per node: ${NUM_GPUs}" + +# Source the environment +source "${VENV_PATH:+${VENV_PATH%/}/}bin/activate" +echo "Active Environment" +which python + +nohup python -m research.cifar10.preprocess \ + --dataset_dir "${DATASET_DIR}" \ + --save_dataset_dir "${OUTPUT_DIR}" \ + --seed "${SEED}" \ + --beta "${BETA}" \ + --num_clients "${NUM_PARTITIONS}" \ + > "${LOG_PATH}" 2>&1 + +echo "Done" diff --git a/research/cifar10/adaptive_pfl/data_preprocess_scripts/preprocess_all.sh b/research/cifar10/adaptive_pfl/data_preprocess_scripts/preprocess_all.sh new file mode 100644 index 000000000..bc28cd567 --- /dev/null +++ b/research/cifar10/adaptive_pfl/data_preprocess_scripts/preprocess_all.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +VENV_PATH=$1 + +SEEDS=( 2024 2025 2026 ) +BETAS=( 0.1 0.5 5.0 ) +NUM_PARTITIONS=( 7 7 7 ) + +ORIGINAL_DATA_DIR="PLACEHOLDER" +DESTINATION_DIRS=( \ + "DEST1" \ + "DEST2" \ + "DEST3" \ + ) + +echo "Python Venv Path: ${VENV_PATH}" + +for index in "${!DESTINATION_DIRS[@]}"; +do + + echo "Preprocessing CIFAR with SEED: 
${SEEDS[index]} and BETA: ${BETAS[index]}" + echo "Number of partitions: ${NUM_PARTITIONS[index]}" + echo "Destination of partitions: ${DESTINATION_DIRS[index]}" + + CLIENT_OUT_LOGS="cifar_preprocess_log_${SEEDS[index]}_${BETAS[index]}_${NUM_PARTITIONS[index]}.out" # index every array so each job gets its own log files + CLIENT_ERROR_LOGS="cifar_preprocess_log_${SEEDS[index]}_${BETAS[index]}_${NUM_PARTITIONS[index]}.err" + + SBATCH_COMMAND="--job-name=cifar_preprocess_${BETAS[index]} --output=${CLIENT_OUT_LOGS} --error=${CLIENT_ERROR_LOGS} \ + research/cifar10/adaptive_pfl/data_preprocess_scripts/preprocess.slrm \ + ${VENV_PATH} ${ORIGINAL_DATA_DIR} ${DESTINATION_DIRS[index]} ${SEEDS[index]} ${BETAS[index]} \ + ${NUM_PARTITIONS[index]} ${DESTINATION_DIRS[index]}" \ + + sbatch ${SBATCH_COMMAND} +done + +echo "Preprocess Jobs Launched" diff --git a/research/cifar10/adaptive_pfl/ditto/__init__.py b/research/cifar10/adaptive_pfl/ditto/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/__init__.py b/research/cifar10/adaptive_pfl/fenda_ditto/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/research/cifar10/adaptive_pfl/mrmtl/__init__.py b/research/cifar10/adaptive_pfl/mrmtl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/research/cifar10/fed_dgga_pfl/__init__.py b/research/cifar10/fed_dgga_pfl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/research/cifar10/fed_dgga_pfl/ditto/__init__.py b/research/cifar10/fed_dgga_pfl/ditto/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/research/cifar10/fed_dgga_pfl/fenda/__init__.py b/research/cifar10/fed_dgga_pfl/fenda/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/__init__.py b/research/cifar10/fed_dgga_pfl/fenda_ditto/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/research/cifar10/model.py b/research/cifar10/model.py new file mode 100644 index 000000000..287d2a570 --- /dev/null +++ 
b/research/cifar10/model.py @@ -0,0 +1,47 @@ +import torch +import torch.nn as nn +from torch.nn import BatchNorm2d, Conv2d, Flatten, Linear, MaxPool2d, Module, ReLU + + +class ConvNet(Module): + + def __init__( + self, + in_channels: int, + h: int = 32, + w: int = 32, + hidden: int = 2048, + class_num: int = 10, + use_bn: bool = True, + dropout: float = 0.0, + ) -> None: + super().__init__() + + self.conv1 = Conv2d(in_channels, 32, 5, padding=2) + self.conv2 = Conv2d(32, 64, 5, padding=2) + self.use_bn = use_bn + if use_bn: + self.bn1 = BatchNorm2d(32) + self.bn2 = BatchNorm2d(64) + + self.fc1 = Linear((h // 2 // 2) * (w // 2 // 2) * 64, hidden) + self.fc2 = Linear(hidden, class_num) + + self.relu = ReLU(inplace=True) + self.maxpool = MaxPool2d(2) + self.dropout_layer = nn.Dropout(p=dropout) + self.flatten = Flatten() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + x = self.bn1(self.conv1(x)) if self.use_bn else self.conv1(x) + x = self.maxpool(self.relu(x)) + x = self.bn2(self.conv2(x)) if self.use_bn else self.conv2(x) + x = self.maxpool(self.relu(x)) + x = self.flatten(x) + x = self.dropout_layer(x) + x = self.relu(self.fc1(x)) + x = self.dropout_layer(x) + x = self.fc2(x) + + return x diff --git a/research/cifar10/preprocess.py b/research/cifar10/preprocess.py new file mode 100644 index 000000000..b637af0ff --- /dev/null +++ b/research/cifar10/preprocess.py @@ -0,0 +1,185 @@ +import argparse +import os +from logging import INFO +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +import torch +import torchvision.transforms as transforms +from flwr.common.logger import log +from torch.utils.data import DataLoader + +from fl4health.utils.dataset import TensorDataset +from fl4health.utils.load_data import ToNumpy, get_cifar10_data_and_target_tensors, split_data_and_targets +from fl4health.utils.partitioners import DirichletLabelBasedAllocation +from fl4health.utils.random import set_all_random_seeds + + +def 
get_preprocessed_data( + dataset_dir: Path, client_num: int, batch_size: int, beta: float +) -> Tuple[DataLoader, DataLoader, Dict[str, int]]: + transform = transforms.Compose( + [ + ToNumpy(), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + try: + train_data = torch.from_numpy(np.load(f"{dataset_dir}/beta_{beta}/client_{client_num}_train_data.npy")) + train_targets = torch.from_numpy(np.load(f"{dataset_dir}/beta_{beta}/client_{client_num}_train_targets.npy")) + except FileNotFoundError: + raise FileNotFoundError(f"Client {client_num} does not have partitioned train data") + + training_set = TensorDataset(train_data, train_targets, transform=transform, target_transform=None) + + try: + validation_data = torch.from_numpy( + np.load(f"{dataset_dir}/beta_{beta}/client_{client_num}_validation_data.npy") + ) + validation_targets = torch.from_numpy( + np.load(f"{dataset_dir}/beta_{beta}/client_{client_num}_validation_targets.npy") + ) + except FileNotFoundError: + raise FileNotFoundError(f"Client {client_num} does not have partitioned validation data") + + validation_set = TensorDataset(validation_data, validation_targets, transform=transform, target_transform=None) + + train_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True) + validation_loader = DataLoader(validation_set, batch_size=batch_size) + num_examples = { + "train_set": len(training_set), + "validation_set": len(validation_set), + } + + return train_loader, validation_loader, num_examples + + +def get_test_preprocessed_data( + dataset_dir: Path, client_num: int, batch_size: int, beta: float +) -> Tuple[DataLoader, Dict[str, int]]: + transform = transforms.Compose( + [ + ToNumpy(), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + try: + data = torch.from_numpy(np.load(f"{dataset_dir}/beta_{beta}/client_{client_num}_test_data.npy")) + targets = 
torch.from_numpy(np.load(f"{dataset_dir}/beta_{beta}/client_{client_num}_test_targets.npy")) + except FileNotFoundError: + raise FileNotFoundError(f"Client {client_num} does not have partitioned test data") + + evaluation_set = TensorDataset(data, targets, transform=transform, target_transform=None) + + evaluation_loader = DataLoader(evaluation_set, batch_size=batch_size, shuffle=False) + num_examples = {"eval_set": len(evaluation_set)} + + return evaluation_loader, num_examples + + +def preprocess_data( + dataset_dir: Path, num_clients: int, beta: float +) -> Tuple[List[TensorDataset], List[TensorDataset], List[TensorDataset]]: + # Get raw data + data, targets = get_cifar10_data_and_target_tensors(dataset_dir, True) + + train_data, train_targets, val_data, val_targets = split_data_and_targets( + data, + targets, + validation_proportion=0.1, + ) + + training_set = TensorDataset(train_data, train_targets, transform=None, target_transform=None) + validation_set = TensorDataset(val_data, val_targets, transform=None, target_transform=None) + + test_data, test_targets = get_cifar10_data_and_target_tensors(dataset_dir, False) + test_set = TensorDataset(test_data, test_targets, transform=None, target_transform=None) + + # Partition train data + heterogeneous_partitioner = DirichletLabelBasedAllocation( + number_of_partitions=num_clients, unique_labels=list(range(10)), beta=beta, min_label_examples=2 + ) + train_partitioned_datasets, train_partitioned_dist = heterogeneous_partitioner.partition_dataset( + training_set, max_retries=5 + ) + + # Partition validation and test data + heterogeneous_partitioner_with_prior = DirichletLabelBasedAllocation( + number_of_partitions=num_clients, unique_labels=list(range(10)), prior_distribution=train_partitioned_dist + ) + validation_partitioned_datasets, _ = heterogeneous_partitioner_with_prior.partition_dataset( + validation_set, max_retries=5 + ) + test_partitioned_datasets, _ = 
heterogeneous_partitioner_with_prior.partition_dataset(test_set, max_retries=5) + + return train_partitioned_datasets, validation_partitioned_datasets, test_partitioned_datasets + + +def save_preprocessed_data( + save_dataset_dir: Path, partitioned_datasets: List[TensorDataset], beta: float, mode: str +) -> None: + save_dir_path = f"{save_dataset_dir}/beta_{beta}" + os.makedirs(save_dir_path, exist_ok=True) + + for client in range(len(partitioned_datasets)): + save_data = partitioned_datasets[client].data + save_targets = partitioned_datasets[client].targets + np.save(f"{save_dir_path}/client_{client}_{mode}_data.npy", save_data.numpy()) + if save_targets is not None: + np.save(f"{save_dir_path}/client_{client}_{mode}_targets.npy", save_targets.numpy()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FL Client Main") + parser.add_argument( + "--dataset_dir", + action="store", + type=str, + help="Path to the raw Cifar 10 Dataset", + required=True, + ) + parser.add_argument( + "--save_dataset_dir", + action="store", + type=str, + help="Path to save the preprocessed Cifar 10 Dataset", + required=True, + ) + parser.add_argument( + "--seed", + action="store", + type=int, + help="Seed for the random number generators across python, torch, and numpy", + required=False, + ) + parser.add_argument( + "--beta", + action="store", + type=float, + help="Heterogeneity level for the dataset", + required=True, + ) + parser.add_argument( + "--num_clients", + action="store", + type=int, + help="Number of clients to partition the dataset into", + default=5, + ) + args = parser.parse_args() + log(INFO, f"Seed: {args.seed}") + log(INFO, f"Beta: {args.beta}") + log(INFO, f"Number of clients: {args.num_clients}") + + # Set the random seed for reproducibility + set_all_random_seeds(args.seed) + + train_partitioned_datasets, validation_partitioned_datasets, test_partitioned_datasets = preprocess_data( + Path(args.dataset_dir), args.num_clients, args.beta + ) 
+ save_preprocessed_data(Path(args.save_dataset_dir), train_partitioned_datasets, args.beta, "train") + save_preprocessed_data(Path(args.save_dataset_dir), validation_partitioned_datasets, args.beta, "validation") + save_preprocessed_data(Path(args.save_dataset_dir), test_partitioned_datasets, args.beta, "test") From 5883961f4e31052091185d70eb35b9ab67bc92e7 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Thu, 3 Oct 2024 14:16:03 -0400 Subject: [PATCH 03/19] Setting up the adaptive pfl experiments for Ditto, FedProx, Fenda+Ditto, and MR-MTL. This involves pulling in some code associated with Sana's CIFAR-10 experiments as well. So that code has been added too. --- examples/fenda_ditto_example/client.py | 6 +- fl4health/clients/fenda_ditto_client.py | 8 +- fl4health/utils/metrics.py | 2 +- fl4health/utils/sampler.py | 33 +- research/cifar10/adaptive_pfl/ditto/client.py | 164 ++++ .../cifar10/adaptive_pfl/ditto/config.yaml | 7 + .../ditto/run_fold_experiment.slrm | 179 +++++ .../adaptive_pfl/ditto/run_hp_sweep.sh | 76 ++ research/cifar10/adaptive_pfl/ditto/server.py | 147 ++++ .../cifar10/adaptive_pfl/fedprox/__init__.py | 0 .../cifar10/adaptive_pfl/fedprox/client.py | 168 ++++ .../cifar10/adaptive_pfl/fedprox/config.yaml | 7 + .../fedprox/run_fold_experiment.slrm | 183 +++++ .../adaptive_pfl/fedprox/run_hp_sweep.sh | 76 ++ .../cifar10/adaptive_pfl/fedprox/server.py | 159 ++++ .../adaptive_pfl/fenda_ditto/client.py | 179 +++++ .../adaptive_pfl/fenda_ditto/config.yaml | 7 + .../fenda_ditto/run_fold_experiment.slrm | 198 +++++ .../adaptive_pfl/fenda_ditto/run_hp_sweep.sh | 84 ++ .../adaptive_pfl/fenda_ditto/server.py | 147 ++++ research/cifar10/adaptive_pfl/mrmtl/client.py | 162 ++++ .../cifar10/adaptive_pfl/mrmtl/config.yaml | 7 + .../mrmtl/run_fold_experiment.slrm | 179 +++++ .../adaptive_pfl/mrmtl/run_hp_sweep.sh | 76 ++ research/cifar10/adaptive_pfl/mrmtl/server.py | 147 ++++ research/cifar10/evaluate_on_test.py | 748 
++++++++++++++++++ research/cifar10/find_best_hp.py | 60 ++ research/cifar10/model.py | 106 ++- research/cifar10/personal_server.py | 63 ++ .../preprocess.slrm | 0 .../preprocess_all.sh | 2 +- research/cifar10/utils.py | 94 +++ 32 files changed, 3459 insertions(+), 15 deletions(-) create mode 100644 research/cifar10/adaptive_pfl/ditto/client.py create mode 100644 research/cifar10/adaptive_pfl/ditto/config.yaml create mode 100644 research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm create mode 100755 research/cifar10/adaptive_pfl/ditto/run_hp_sweep.sh create mode 100644 research/cifar10/adaptive_pfl/ditto/server.py create mode 100644 research/cifar10/adaptive_pfl/fedprox/__init__.py create mode 100644 research/cifar10/adaptive_pfl/fedprox/client.py create mode 100644 research/cifar10/adaptive_pfl/fedprox/config.yaml create mode 100644 research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm create mode 100755 research/cifar10/adaptive_pfl/fedprox/run_hp_sweep.sh create mode 100644 research/cifar10/adaptive_pfl/fedprox/server.py create mode 100644 research/cifar10/adaptive_pfl/fenda_ditto/client.py create mode 100644 research/cifar10/adaptive_pfl/fenda_ditto/config.yaml create mode 100644 research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm create mode 100755 research/cifar10/adaptive_pfl/fenda_ditto/run_hp_sweep.sh create mode 100644 research/cifar10/adaptive_pfl/fenda_ditto/server.py create mode 100644 research/cifar10/adaptive_pfl/mrmtl/client.py create mode 100644 research/cifar10/adaptive_pfl/mrmtl/config.yaml create mode 100644 research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm create mode 100755 research/cifar10/adaptive_pfl/mrmtl/run_hp_sweep.sh create mode 100644 research/cifar10/adaptive_pfl/mrmtl/server.py create mode 100644 research/cifar10/evaluate_on_test.py create mode 100644 research/cifar10/find_best_hp.py create mode 100644 research/cifar10/personal_server.py rename 
research/cifar10/{adaptive_pfl/data_preprocess_scripts => pfl_preprocess_scripts}/preprocess.slrm (100%) rename research/cifar10/{adaptive_pfl/data_preprocess_scripts => pfl_preprocess_scripts}/preprocess_all.sh (92%) create mode 100644 research/cifar10/utils.py diff --git a/examples/fenda_ditto_example/client.py b/examples/fenda_ditto_example/client.py index ba7123365..6fe878721 100644 --- a/examples/fenda_ditto_example/client.py +++ b/examples/fenda_ditto_example/client.py @@ -21,7 +21,7 @@ from fl4health.clients.fenda_ditto_client import FendaDittoClient from fl4health.model_bases.fenda_base import FendaModel from fl4health.model_bases.parallel_split_models import ParallelFeatureJoinMode -from fl4health.model_bases.sequential_split_models import SequentiallySplitExchangeBaseModel +from fl4health.model_bases.sequential_split_models import SequentiallySplitModel from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_mnist_data from fl4health.utils.metrics import Accuracy @@ -36,8 +36,8 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler) return train_loader, val_loader - def get_global_model(self, config: Config) -> SequentiallySplitExchangeBaseModel: - return SequentiallySplitExchangeBaseModel( + def get_global_model(self, config: Config) -> SequentiallySplitModel: + return SequentiallySplitModel( base_module=SequentialGlobalFeatureExtractorMnist(), head_module=SequentialLocalPredictionHeadMnist(), ).to(self.device) diff --git a/fl4health/clients/fenda_ditto_client.py b/fl4health/clients/fenda_ditto_client.py index 69abbe4db..6bf063e5a 100644 --- a/fl4health/clients/fenda_ditto_client.py +++ b/fl4health/clients/fenda_ditto_client.py @@ -9,7 +9,7 @@ from fl4health.checkpointing.client_module import ClientCheckpointModule from fl4health.clients.ditto_client import DittoClient from fl4health.model_bases.fenda_base 
import FendaModel -from fl4health.model_bases.sequential_split_models import SequentiallySplitExchangeBaseModel +from fl4health.model_bases.sequential_split_models import SequentiallySplitModel from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking from fl4health.reporting.metrics import MetricsReporter from fl4health.utils.losses import LossMeterType, TrainingLosses @@ -89,7 +89,7 @@ def __init__( metrics_reporter=metrics_reporter, progress_bar=progress_bar, ) - self.global_model: SequentiallySplitExchangeBaseModel + self.global_model: SequentiallySplitModel self.model: FendaModel self.freeze_global_feature_extractor = freeze_global_feature_extractor @@ -108,7 +108,7 @@ def get_model(self, config: Config) -> FendaModel: """ raise NotImplementedError("This function must be defined in the inheriting class to use this client") - def get_global_model(self, config: Config) -> SequentiallySplitExchangeBaseModel: + def get_global_model(self, config: Config) -> SequentiallySplitModel: """ User defined method that returns a Global Sequential Model that is compatible with the local FENDA model. @@ -116,7 +116,7 @@ def get_global_model(self, config: Config) -> SequentiallySplitExchangeBaseModel config (Config): The config from the server. Returns: - SequentiallySplitExchangeBaseModel: The global (Ditto) model. + SequentiallySplitModel: The global (Ditto) model. Raises: NotImplementedError: To be defined in child class. diff --git a/fl4health/utils/metrics.py b/fl4health/utils/metrics.py index 7a33773ce..642dd8c1c 100644 --- a/fl4health/utils/metrics.py +++ b/fl4health/utils/metrics.py @@ -336,7 +336,7 @@ def __init__( average: Optional[str] = "weighted", ): """ - Computes the F1 score using the sklearn f1_score function. As such, the values of average are correspond to + Computes the F1 score using the sklearn f1_score function. As such, the values of average correspond to those of that function. 
Args: diff --git a/fl4health/utils/sampler.py b/fl4health/utils/sampler.py index 894211392..6dddd5975 100644 --- a/fl4health/utils/sampler.py +++ b/fl4health/utils/sampler.py @@ -1,9 +1,11 @@ import math from abc import ABC, abstractmethod -from typing import Any, List, Set, TypeVar, Union +from logging import INFO +from typing import Any, List, Optional, Set, TypeVar, Union import numpy as np import torch +from flwr.common.logger import log from fl4health.utils.dataset import DictionaryDataset, TensorDataset, select_by_indices @@ -95,7 +97,13 @@ def _get_random_subsample(self, tensor_to_subsample: torch.Tensor, subsample_siz class DirichletLabelBasedSampler(LabelBasedSampler): - def __init__(self, unique_labels: List[Any], sample_percentage: float = 0.5, beta: float = 100) -> None: + def __init__( + self, + unique_labels: List[Any], + hash_key: Optional[int] = None, + sample_percentage: float = 0.5, + beta: float = 100, + ) -> None: """ class used to subsample a dataset so the classes of samples are distributed in a non-IID way. In particular, the DirichletLabelBasedSampler uses a dirichlet distribution to determine the number @@ -116,9 +124,21 @@ class used to subsample a dataset so the classes of samples are distributed in a value is 0.5 and the dataset is of size 100, we will end up with 50 total data points. Defaults to 0.5. beta (float, optional): This controls the heterogeneity of the label sampling. The smaller the beta, the more skewed the label assignments will be for the dataset. Defaults to 100. + hash_key (Optional[int], optional): Seed for the random number generators and samplers. Defaults to None. 
""" super().__init__(unique_labels) - self.probabilities = np.random.dirichlet(np.repeat(beta, self.num_classes)) + + self.hash_key = hash_key + + self.torch_generator = None + if self.hash_key is not None: + log(INFO, f"Setting seed to {self.hash_key} for Numpy and Torch Generators") + self.torch_generator = torch.Generator().manual_seed(self.hash_key) + + self.np_generator = np.random.default_rng(self.hash_key) + self.probabilities = self.np_generator.dirichlet(np.repeat(beta, self.num_classes)) + log(INFO, f"Setting probabilities to {self.probabilities}") + self.sample_percentage = sample_percentage def subsample(self, dataset: D) -> D: @@ -144,10 +164,13 @@ def subsample(self, dataset: D) -> D: # For each class sample the given number of samples from the class specific indices # torch.multinomial is used to uniformly sample indices the size of given number of samples sampled_class_idx_list = [ - class_idx[torch.multinomial(torch.ones(class_idx.size(0)), num_samples, replacement=True)] + class_idx[ + torch.multinomial( + torch.ones(class_idx.size(0)), num_samples, replacement=True, generator=self.torch_generator + ) + ] for class_idx, num_samples in zip(class_idx_list, num_samples_per_class) ] - selected_indices = torch.cat(sampled_class_idx_list, dim=0).long() # Due to precision errors with previous rounding, sum of sample counts diff --git a/research/cifar10/adaptive_pfl/ditto/client.py b/research/cifar10/adaptive_pfl/ditto/client.py new file mode 100644 index 000000000..f0a916ed4 --- /dev/null +++ b/research/cifar10/adaptive_pfl/ditto/client.py @@ -0,0 +1,164 @@ +import argparse +import os +from logging import INFO +from pathlib import Path +from typing import Dict, Optional, Sequence, Tuple + +import flwr as fl +import torch +import torch.nn as nn +from flwr.common.logger import log +from flwr.common.typing import Config +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from 
fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.clients.ditto_client import DittoClient +from fl4health.utils.config import narrow_dict_type +from fl4health.utils.losses import LossMeterType +from fl4health.utils.metrics import F1, Accuracy, Metric +from fl4health.utils.random import set_all_random_seeds +from research.cifar10.model import ConvNet +from research.cifar10.preprocess import get_preprocessed_data + + +class CifarDittoClient(DittoClient): + def __init__( + self, + data_path: Path, + metrics: Sequence[Metric], + device: torch.device, + client_number: int, + learning_rate: float, + heterogeneity_level: float, + loss_meter_type: LossMeterType = LossMeterType.AVERAGE, + checkpointer: Optional[ClientCheckpointModule] = None, + ) -> None: + super().__init__( + data_path=data_path, + metrics=metrics, + device=device, + loss_meter_type=loss_meter_type, + checkpointer=checkpointer, + ) + self.client_number = client_number + self.heterogeneity_level = heterogeneity_level + self.learning_rate: float = learning_rate + + log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") + + def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + batch_size = narrow_dict_type(config, "batch_size", int) + train_loader, val_loader, _ = get_preprocessed_data( + self.data_path, self.client_number, batch_size, self.heterogeneity_level + ) + return train_loader, val_loader + + def get_criterion(self, config: Config) -> _Loss: + return torch.nn.CrossEntropyLoss() + + def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=self.learning_rate) + local_optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) + return {"global": global_optimizer, "local": local_optimizer} + + def get_model(self, config: 
Config) -> nn.Module: + return ConvNet(in_channels=3, use_bn=False).to(self.device) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FL Client Main") + parser.add_argument( + "--artifact_dir", + action="store", + type=str, + help="Path to save client artifacts such as logs and model checkpoints", + required=True, + ) + parser.add_argument( + "--dataset_dir", + action="store", + type=str, + help="Path to the preprocessed Cifar 10 Dataset", + required=True, + ) + parser.add_argument( + "--run_name", + action="store", + help="Name of the run, model checkpoints will be saved under a subfolder with this name", + required=True, + ) + parser.add_argument( + "--server_address", + action="store", + type=str, + help="Server Address for the clients to communicate with the server through", + default="0.0.0.0:8080", + ) + parser.add_argument( + "--client_number", + action="store", + type=int, + help="Number of the client for dataset loading", + required=True, + ) + parser.add_argument( + "--learning_rate", action="store", type=float, help="Learning rate for local optimization", default=0.1 + ) + parser.add_argument( + "--seed", + action="store", + type=int, + help="Seed for the random number generators across python, torch, and numpy", + required=False, + ) + parser.add_argument( + "--beta", + action="store", + type=float, + help="Heterogeneity level for the dataset", + required=True, + ) + args = parser.parse_args() + + DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {DEVICE}") + log(INFO, f"Server Address: {args.server_address}") + log(INFO, f"Learning Rate: {args.learning_rate}") + log(INFO, f"Beta: {args.beta}") + + # Set the random seed for reproducibility + set_all_random_seeds(args.seed) + + # Adding extensive checkpointing for the client + checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) + pre_aggregation_best_checkpoint_name = 
f"pre_aggregation_client_{args.client_number}_best_model.pkl" + pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" + post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" + post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" + checkpointer = ClientCheckpointModule( + pre_aggregation=[ + BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), + LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), + ], + post_aggregation=[ + BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), + LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), + ], + ) + + data_path = Path(args.dataset_dir) + client = CifarDittoClient( + data_path=data_path, + metrics=[Accuracy("accuracy"), F1("F1_Score", average="macro")], + device=DEVICE, + client_number=args.client_number, + learning_rate=args.learning_rate, + heterogeneity_level=args.beta, + checkpointer=checkpointer, + ) + + fl.client.start_client(server_address=args.server_address, client=client.to_client()) + client.shutdown() diff --git a/research/cifar10/adaptive_pfl/ditto/config.yaml b/research/cifar10/adaptive_pfl/ditto/config.yaml new file mode 100644 index 000000000..323b2a693 --- /dev/null +++ b/research/cifar10/adaptive_pfl/ditto/config.yaml @@ -0,0 +1,7 @@ +# Parameters that describe server +n_server_rounds: 50 # The number of rounds to run FL + +# Parameters that describe clients +n_clients: 7 # The number of clients in the FL experiment +local_epochs: 1 # The number of epochs to complete for client +batch_size: 32 # The batch size for client training diff --git a/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm new file mode 100644 index 000000000..05749e8b0 --- /dev/null +++ 
b/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm @@ -0,0 +1,179 @@ +#!/bin/bash + +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --mem=32G +#SBATCH --partition=a40 +#SBATCH --qos=m2 +#SBATCH --job-name=fl_five_fold_exp +#SBATCH --output=%j_%x.out +#SBATCH --error=%j_%x.err +#SBATCH --time=4:00:00 + +############################################### +# Usage: +# +# sbatch research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm \ +# path_to_config.yaml \ +# path_to_folder_for_artifacts/ \ +# path_to_folder_for_dataset/ \ +# path_to_desired_venv/ \ +# client_side_learning_rate_value \ +# lambda_value \ +# server_address \ +# client_beta \ +# adapt +# +# Example: +# sbatch research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm \ +# research/cifar10/adaptive_pfl/ditto/config.yaml \ +# research/cifar10/adaptive_pfl/ditto/hp_results/ \ +# /datasets/cifar10 \ +# /h/demerson/vector_repositories/fl4health_env/ \ +# 0.0001 \ +# 0.01 \ +# 0.0.0.0:8080 \ +# 0.1 \ +# "TRUE" +# +# Notes: +# 1) The sbatch command above should be run from the top level directory of the repository. +# 2) This example runs ditto. As such the data paths and python launch commands are hardcoded. If you want to change +# the example you run, you need to explicitly modify the code below. +# 3) The parent of the artifacts folder needs to ALREADY EXIST. The script only runs a plain mkdir for the artifacts folder and each run folder. +############################################### + +# Note: +# ntasks: Total number of processes to use across world +# ntasks-per-node: How many processes each node should create + +# Set NCCL options +# export NCCL_DEBUG=INFO +# NCCL backend to communicate between GPU workers is not provided in vector's cluster. +# Disable this option in slurm. 
+export NCCL_IB_DISABLE=1 + +if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ + [[ "${SLURM_JOB_PARTITION}" == "rtx6000" ]]; then + echo export NCCL_SOCKET_IFNAME=bond0 on "${SLURM_JOB_PARTITION}" + export NCCL_SOCKET_IFNAME=bond0 +fi + +# Process Inputs + +SERVER_CONFIG_PATH=$1 +ARTIFACT_DIR=$2 +DATASET_DIR=$3 +VENV_PATH=$4 +CLIENT_LR=$5 +LAM_VALUE=$6 +SERVER_ADDRESS=$7 +CLIENT_BETA=$8 +ADAPT=$9 + +# Create the artifact directory +mkdir "${ARTIFACT_DIR}" + +RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) +SEEDS=(2021 2022 2023 2024 2025) + +echo "Python Venv Path: ${VENV_PATH}" + +echo "World size: ${SLURM_NTASKS}" +echo "Number of nodes: ${SLURM_NNODES}" +NUM_GPUs=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) +echo "GPUs per node: ${NUM_GPUs}" + +# Source the environment +source ${VENV_PATH}bin/activate +echo "Active Environment:" +which python + +for ((i=0; i<${#RUN_NAMES[@]}; i++)); +do + RUN_NAME="${RUN_NAMES[i]}" + SEED="${SEEDS[i]}" + # create the run directory + RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" + echo "Starting Run and logging artifcats at ${RUN_DIR}" + if [ -d "${RUN_DIR}" ] + then + # Directory already exists, we check if the done.out file exists + if [ -f "${RUN_DIR}done.out" ] + then + # Done file already exists so we skip this run + echo "Run already completed. Skipping Run." + continue + else + # Done file doesn't exists (assume pre-emption happened) + # Delete the partially finished contents and start over + echo "Run did not finished correctly. Re-running." + rm -r "${RUN_DIR}" + mkdir "${RUN_DIR}" + fi + else + # Directory doesn't exist yet, so we create it. + echo "Run directory does not exist. Creating it." 
+ mkdir "${RUN_DIR}" + fi + + SERVER_OUTPUT_FILE="${RUN_DIR}server.out" + + # Start the server, divert the outputs to a server file + + echo "Server logging at: ${SERVER_OUTPUT_FILE}" + echo "Launching Server" + + if [[ ${ADAPT} == "TRUE" ]]; then + nohup python -m research.cifar10.adaptive_pfl.ditto.server \ + --config_path ${SERVER_CONFIG_PATH} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --lam ${LAM_VALUE} \ + --use_adaptation \ + > ${SERVER_OUTPUT_FILE} 2>&1 & + else + nohup python -m research.cifar10.adaptive_pfl.ditto.server \ + --config_path ${SERVER_CONFIG_PATH} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --lam ${LAM_VALUE} \ + > ${SERVER_OUTPUT_FILE} 2>&1 & + fi + + # Sleep for 20 seconds to allow the server to come up. + sleep 20 + + # Start n number of clients and divert the outputs to their own files + n_clients=7 + for (( c=0; c<${n_clients}; c++ )) + do + CLIENT_NAME="client_${c}" + echo "Launching ${CLIENT_NAME}" + + CLIENT_LOG_PATH="${RUN_DIR}${CLIENT_NAME}.out" + echo "${CLIENT_NAME} logging at: ${CLIENT_LOG_PATH}" + nohup python -m research.cifar10.adaptive_pfl.ditto.client \ + --artifact_dir ${ARTIFACT_DIR} \ + --dataset_dir ${DATASET_DIR} \ + --run_name ${RUN_NAME} \ + --client_number ${c} \ + --learning_rate ${CLIENT_LR} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --beta ${CLIENT_BETA} \ + > ${CLIENT_LOG_PATH} 2>&1 & + done + + echo "FL Processes Running" + + wait + + # Create a file that verifies that the Run concluded properly + touch "${RUN_DIR}done.out" + echo "Finished FL Processes" + +done diff --git a/research/cifar10/adaptive_pfl/ditto/run_hp_sweep.sh b/research/cifar10/adaptive_pfl/ditto/run_hp_sweep.sh new file mode 100755 index 000000000..6eff88615 --- /dev/null +++ b/research/cifar10/adaptive_pfl/ditto/run_hp_sweep.sh @@ -0,0 +1,76 @@ +#!/bin/bash + +############################################### +# Usage: +# +# ./research/cifar10/adaptive_pfl/ditto/run_hp_sweep.sh \ +# 
path_to_config.yaml \ +# path_to_folder_for_artifacts/ \ +# path_to_folder_for_dataset/ \ +# path_to_desired_venv/ +# +# Example: +# ./research/cifar10/adaptive_pfl/ditto/run_hp_sweep.sh \ +# research/cifar10/adaptive_pfl/ditto/config.yaml \ +# research/cifar10/adaptive_pfl/ditto/ \ +# /datasets/cifar10 \ +# /h/demerson/vector_repositories/fl4health_env/ +# +# Notes: +# 1) The bash command above should be run from the top level directory of the repository. +############################################### + +SERVER_CONFIG_PATH=$1 +ARTIFACT_DIR=$2 +DATASET_DIR=$3 +VENV_PATH=$4 + +LR_VALUES=( 0.0001 0.001 0.01 0.1 ) +# Note: These values must correspond to values for the preprocessed CIFAR datasets +BETA_VALUES=( 0.1 0.5 5.0 ) +LAM_VALUES=( 0.001 1.0 ) +ADAPT_BOOL=( "TRUE" "FALSE" ) + +SERVER_PORT=8100 + +# Create sweep folder +SWEEP_DIRECTORY="${ARTIFACT_DIR}hp_sweep_results" +echo "Creating sweep folder at ${SWEEP_DIRECTORY}" +mkdir ${SWEEP_DIRECTORY} + +for BETA_VALUE in "${BETA_VALUES[@]}"; do + echo "Creating folder for beta ${BETA_VALUE}" + mkdir "${SWEEP_DIRECTORY}/beta_${BETA_VALUE}" + for LR_VALUE in "${LR_VALUES[@]}"; + do + for LAM_VALUE in "${LAM_VALUES[@]}"; + do + for ADAPT in "${ADAPT_BOOL[@]}"; + do + EXPERIMENT_NAME="lr_${LR_VALUE}_beta_${BETA_VALUE}_lam_${LAM_VALUE}" + if [[ ${ADAPT} == "TRUE" ]]; then + EXPERIMENT_NAME="${EXPERIMENT_NAME}_adapt" + fi + echo "Beginning Experiment ${EXPERIMENT_NAME}" + EXPERIMENT_DIRECTORY="${SWEEP_DIRECTORY}/beta_${BETA_VALUE}/${EXPERIMENT_NAME}/" + echo "Creating experiment folder ${EXPERIMENT_DIRECTORY}" + mkdir "${EXPERIMENT_DIRECTORY}" + SERVER_ADDRESS="0.0.0.0:${SERVER_PORT}" + echo "Server Address: ${SERVER_ADDRESS}" + SBATCH_COMMAND="research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm \ + ${SERVER_CONFIG_PATH} \ + ${EXPERIMENT_DIRECTORY} \ + ${DATASET_DIR} \ + ${VENV_PATH} \ + ${LR_VALUE} \ + ${LAM_VALUE} \ + ${SERVER_ADDRESS} \ + ${BETA_VALUE} \ + ${ADAPT}" + sbatch ${SBATCH_COMMAND} + 
((SERVER_PORT=SERVER_PORT+1)) + done + done + done +done +echo Experiments Launched diff --git a/research/cifar10/adaptive_pfl/ditto/server.py b/research/cifar10/adaptive_pfl/ditto/server.py new file mode 100644 index 000000000..8fdeecaaf --- /dev/null +++ b/research/cifar10/adaptive_pfl/ditto/server.py @@ -0,0 +1,147 @@ +import argparse +from functools import partial +from logging import INFO +from typing import Any, Dict, Optional + +import flwr as fl +from flwr.common.logger import log +from flwr.common.typing import Config +from flwr.server.client_manager import ClientManager, SimpleClientManager +from flwr.server.strategy import Strategy + +from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint +from fl4health.utils.config import load_config +from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn +from fl4health.utils.parameter_extraction import get_all_model_parameters +from fl4health.utils.random import set_all_random_seeds +from research.cifar10.model import ConvNet +from research.cifar10.personal_server import PersonalServer + + +class PersonalDittoServer(PersonalServer): + """ + The PersonalServer class is used for FL approaches that only have a sense of a PERSONAL model that is checkpointed + and valid only on the client size of the FL training framework. FL approaches like APFL and FENDA fall under this + category. Each client will have its own model that is specific to its own training. Personal models may have + shared components but the full model is specific to each client. This is distinct from the + FlServerWithCheckpointing class which has a sense of a GLOBAL model checkpointed on the server-side that is + shared by all clients. 
+ """ + + def __init__( + self, + client_manager: ClientManager, + strategy: Optional[Strategy] = None, + ) -> None: + assert isinstance( + strategy, FedAvgWithAdaptiveConstraint + ), "Strategy must be of base type FedAvgWithAdaptiveConstraint" + # Personal approaches don't train a "server" model. Rather, each client trains a client specific model with + # some globally shared weights. So we don't checkpoint a global model + super().__init__(client_manager, strategy) + + +def fit_config( + batch_size: int, + local_epochs: int, + n_server_rounds: int, + n_clients: int, + current_server_round: int, +) -> Config: + return { + "batch_size": batch_size, + "local_epochs": local_epochs, + "n_server_rounds": n_server_rounds, + "n_clients": n_clients, + "current_server_round": current_server_round, + } + + +def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_weight: bool) -> None: + # This function will be used to produce a config that is sent to each client to initialize their own environment + fit_config_fn = partial( + fit_config, + config["batch_size"], + config["local_epochs"], + config["n_server_rounds"], + config["n_clients"], + ) + + client_manager = SimpleClientManager() + # Initializing the model on the server side + model = ConvNet(in_channels=3, use_bn=False) + # Server performs simple FedAveraging as its server-side optimization strategy + strategy = FedAvgWithAdaptiveConstraint( + min_fit_clients=config["n_clients"], + min_evaluate_clients=config["n_clients"], + # Server waits for min_available_clients before starting FL rounds + min_available_clients=config["n_clients"], + on_fit_config_fn=fit_config_fn, + # We use the same fit config function, as nothing changes for eval + on_evaluate_config_fn=fit_config_fn, + fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, + initial_parameters=get_all_model_parameters(model), + initial_loss_weight=lam, + 
adapt_loss_weight=adapt_loss_weight, + ) + + server = PersonalDittoServer(client_manager=client_manager, strategy=strategy) + + fl.server.start_server( + server=server, + server_address=server_address, + config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]), + ) + + log(INFO, "Training Complete") + log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{server.best_aggregated_loss}") + + # Shutdown the server gracefully + server.shutdown() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FL Server Main") + parser.add_argument( + "--config_path", + action="store", + type=str, + help="Path to configuration file.", + default="config.yaml", + ) + parser.add_argument( + "--server_address", + action="store", + type=str, + help="Server Address to be used to communicate with the clients", + default="0.0.0.0:8080", + ) + parser.add_argument( + "--seed", + action="store", + type=int, + help="Seed for the random number generators across python, torch, and numpy", + required=False, + ) + parser.add_argument( + "--lam", action="store", type=float, help="Ditto loss weight for local model training", default=0.01 + ) + parser.add_argument( + "--use_adaptation", + action="store_true", + help="Whether or not the loss weight for model drift is adapted or remains fixed.", + default=False, + ) + args = parser.parse_args() + + config = load_config(args.config_path) + log(INFO, f"Server Address: {args.server_address}") + log(INFO, f"Lambda: {args.lam}") + if args.use_adaptation: + log(INFO, "Adapting the loss weight for model drift via global model loss") + + # Set the random seed for reproducibility + set_all_random_seeds(args.seed) + + main(config, args.server_address, args.lam, args.use_adaptation) diff --git a/research/cifar10/adaptive_pfl/fedprox/__init__.py b/research/cifar10/adaptive_pfl/fedprox/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/research/cifar10/adaptive_pfl/fedprox/client.py 
b/research/cifar10/adaptive_pfl/fedprox/client.py new file mode 100644 index 000000000..ed56117a0 --- /dev/null +++ b/research/cifar10/adaptive_pfl/fedprox/client.py @@ -0,0 +1,168 @@ +import argparse +import os +from logging import INFO +from pathlib import Path +from typing import Optional, Sequence, Tuple + +import flwr as fl +import torch +import torch.nn as nn +from flwr.common.logger import log +from flwr.common.typing import Config +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.clients.fed_prox_client import FedProxClient +from fl4health.utils.config import narrow_dict_type +from fl4health.utils.losses import LossMeterType +from fl4health.utils.metrics import F1, Accuracy, Metric +from fl4health.utils.random import set_all_random_seeds +from research.cifar10.model import ConvNet +from research.cifar10.preprocess import get_preprocessed_data + + +class CifarFedProxClient(FedProxClient): + def __init__( + self, + data_path: Path, + metrics: Sequence[Metric], + device: torch.device, + client_number: int, + learning_rate: float, + heterogeneity_level: float, + loss_meter_type: LossMeterType = LossMeterType.AVERAGE, + checkpointer: Optional[ClientCheckpointModule] = None, + ) -> None: + super().__init__( + data_path=data_path, + metrics=metrics, + device=device, + loss_meter_type=loss_meter_type, + checkpointer=checkpointer, + ) + self.client_number = client_number + self.heterogeneity_level = heterogeneity_level + self.learning_rate: float = learning_rate + + log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") + + def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + batch_size = narrow_dict_type(config, "batch_size", int) + train_loader, 
val_loader, _ = get_preprocessed_data( + self.data_path, self.client_number, batch_size, self.heterogeneity_level + ) + return train_loader, val_loader + + def get_criterion(self, config: Config) -> _Loss: + return torch.nn.CrossEntropyLoss() + + def get_optimizer(self, config: Config) -> Optimizer: + return torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) + + def get_model(self, config: Config) -> nn.Module: + return ConvNet(in_channels=3, use_bn=False).to(self.device) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FL Client Main") + parser.add_argument( + "--artifact_dir", + action="store", + type=str, + help="Path to save client artifacts such as logs and model checkpoints", + required=True, + ) + parser.add_argument( + "--dataset_dir", + action="store", + type=str, + help="Path to the preprocessed Cifar 10 Dataset", + required=True, + ) + parser.add_argument( + "--use_partitioned_data", + action="store_true", + help="Use preprocessed partitioned data for training, validation and testing", + default=False, + ) + parser.add_argument( + "--run_name", + action="store", + help="Name of the run, model checkpoints will be saved under a subfolder with this name", + required=True, + ) + parser.add_argument( + "--server_address", + action="store", + type=str, + help="Server Address for the clients to communicate with the server through", + default="0.0.0.0:8080", + ) + parser.add_argument( + "--client_number", + action="store", + type=int, + help="Number of the client for dataset loading", + required=True, + ) + parser.add_argument( + "--learning_rate", action="store", type=float, help="Learning rate for local optimization", default=0.1 + ) + parser.add_argument( + "--seed", + action="store", + type=int, + help="Seed for the random number generators across python, torch, and numpy", + required=False, + ) + parser.add_argument( + "--beta", + action="store", + type=float, + help="Heterogeneity level for the dataset", + 
required=True, + ) + args = parser.parse_args() + + DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {DEVICE}") + log(INFO, f"Server Address: {args.server_address}") + log(INFO, f"Learning Rate: {args.learning_rate}") + log(INFO, f"Beta: {args.beta}") + + # Set the random seed for reproducibility + set_all_random_seeds(args.seed) + + # Adding extensive checkpointing for the client + checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) + pre_aggregation_best_checkpoint_name = f"pre_aggregation_client_{args.client_number}_best_model.pkl" + pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" + post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" + post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" + checkpointer = ClientCheckpointModule( + pre_aggregation=[ + BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), + LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), + ], + post_aggregation=[ + BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), + LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), + ], + ) + + data_path = Path(args.dataset_dir) + client = CifarFedProxClient( + data_path=data_path, + metrics=[Accuracy("accuracy"), F1("F1_Score", average="macro")], + device=DEVICE, + client_number=args.client_number, + learning_rate=args.learning_rate, + heterogeneity_level=args.beta, + checkpointer=checkpointer, + ) + + fl.client.start_client(server_address=args.server_address, client=client.to_client()) + client.shutdown() diff --git a/research/cifar10/adaptive_pfl/fedprox/config.yaml b/research/cifar10/adaptive_pfl/fedprox/config.yaml new file mode 100644 index 000000000..323b2a693 --- /dev/null +++ 
b/research/cifar10/adaptive_pfl/fedprox/config.yaml @@ -0,0 +1,7 @@ +# Parameters that describe server +n_server_rounds: 50 # The number of rounds to run FL + +# Parameters that describe clients +n_clients: 7 # The number of clients in the FL experiment +local_epochs: 1 # The number of epochs to complete for client +batch_size: 32 # The batch size for client training diff --git a/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm new file mode 100644 index 000000000..47c3dbaf1 --- /dev/null +++ b/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm @@ -0,0 +1,183 @@ +#!/bin/bash + +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --mem=32G +#SBATCH --partition=a40 +#SBATCH --qos=m2 +#SBATCH --job-name=fl_five_fold_exp +#SBATCH --output=%j_%x.out +#SBATCH --error=%j_%x.err +#SBATCH --time=4:00:00 + +############################################### +# Usage: +# +# sbatch research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm \ +# path_to_config.yaml \ +# path_to_folder_for_artifacts/ \ +# path_to_folder_for_dataset/ \ +# path_to_desired_venv/ \ +# client_side_learning_rate_value \ +# lambda_value \ +# server_address \ +# client_beta \ +# adapt +# +# Example: +# sbatch research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm \ +# research/cifar10/adaptive_pfl/fedprox/config.yaml \ +# research/cifar10/adaptive_pfl/fedprox/hp_results/ \ +# /datasets/cifar10 \ +# /h/demerson/vector_repositories/fl4health_env/ \ +# 0.0001 \ +# 0.01 \ +# 0.0.0.0:8080 \ +# 0.1 \ +# "TRUE" +# +# Notes: +# 1) The sbatch command above should be run from the top level directory of the repository. +# 2) This example runs fedprox. As such the data paths and python launch commands are hardcoded. If you want to change +# the example you run, you need to explicitly modify the code below. 
+# 3) The logging directories need to ALREADY EXIST. The script does not create them. +############################################### + +# Note: +# ntasks: Total number of processes to use across world +# ntasks-per-node: How many processes each node should create + +# Set NCCL options +# export NCCL_DEBUG=INFO +# NCCL backend to communicate between GPU workers is not provided in vector's cluster. +# Disable this option in slurm. +export NCCL_IB_DISABLE=1 + +if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ + [[ "${SLURM_JOB_PARTITION}" == "rtx6000" ]]; then + echo export NCCL_SOCKET_IFNAME=bond0 on "${SLURM_JOB_PARTITION}" + export NCCL_SOCKET_IFNAME=bond0 +fi + +# Process Inputs + +SERVER_CONFIG_PATH=$1 +ARTIFACT_DIR=$2 +DATASET_DIR=$3 +VENV_PATH=$4 +CLIENT_LR=$5 +LAM_VALUE=$6 +SERVER_ADDRESS=$7 +CLIENT_BETA=$8 +ADAPT=$9 + +# Create the artifact directory +mkdir "${ARTIFACT_DIR}" + +RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) +SEEDS=(2021 2022 2023 2024 2025) + +echo "Python Venv Path: ${VENV_PATH}" + +echo "World size: ${SLURM_NTASKS}" +echo "Number of nodes: ${SLURM_NNODES}" +NUM_GPUs=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) +echo "GPUs per node: ${NUM_GPUs}" + +# Source the environment +source ${VENV_PATH}bin/activate +echo "Active Environment:" +which python + +for ((i=0; i<${#RUN_NAMES[@]}; i++)); +do + RUN_NAME="${RUN_NAMES[i]}" + SEED="${SEEDS[i]}" + # create the run directory + RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" + echo "Starting Run and logging artifcats at ${RUN_DIR}" + if [ -d "${RUN_DIR}" ] + then + # Directory already exists, we check if the done.out file exists + if [ -f "${RUN_DIR}done.out" ] + then + # Done file already exists so we skip this run + echo "Run already completed. Skipping Run." + continue + else + # Done file doesn't exists (assume pre-emption happened) + # Delete the partially finished contents and start over + echo "Run did not finished correctly. Re-running." 
+ rm -r "${RUN_DIR}" + mkdir "${RUN_DIR}" + fi + else + # Directory doesn't exist yet, so we create it. + echo "Run directory does not exist. Creating it." + mkdir "${RUN_DIR}" + fi + + SERVER_OUTPUT_FILE="${RUN_DIR}server.out" + + # Start the server, divert the outputs to a server file + + echo "Server logging at: ${SERVER_OUTPUT_FILE}" + echo "Launching Server" + + if [[ ${ADAPT} == "TRUE" ]]; then + nohup python -m research.cifar10.adaptive_pfl.fedprox.server \ + --config_path ${SERVER_CONFIG_PATH} \ + --artifact_dir ${ARTIFACT_DIR} \ + --run_name ${RUN_NAME} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --lam ${LAM_VALUE} \ + --use_adaptation \ + > ${SERVER_OUTPUT_FILE} 2>&1 & + else + nohup python -m research.cifar10.adaptive_pfl.fedprox.server \ + --config_path ${SERVER_CONFIG_PATH} \ + --artifact_dir ${ARTIFACT_DIR} \ + --run_name ${RUN_NAME} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --lam ${LAM_VALUE} \ + > ${SERVER_OUTPUT_FILE} 2>&1 & + fi + + # Sleep for 20 seconds to allow the server to come up. 
+ sleep 20 + + # Start n number of clients and divert the outputs to their own files + n_clients=7 + for (( c=0; c<${n_clients}; c++ )) + do + CLIENT_NAME="client_${c}" + echo "Launching ${CLIENT_NAME}" + + CLIENT_LOG_PATH="${RUN_DIR}${CLIENT_NAME}.out" + echo "${CLIENT_NAME} logging at: ${CLIENT_LOG_PATH}" + nohup python -m research.cifar10.adaptive_pfl.fedprox.client \ + --artifact_dir ${ARTIFACT_DIR} \ + --dataset_dir ${DATASET_DIR} \ + --run_name ${RUN_NAME} \ + --client_number ${c} \ + --learning_rate ${CLIENT_LR} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --beta ${CLIENT_BETA} \ + > ${CLIENT_LOG_PATH} 2>&1 & + done + + echo "FL Processes Running" + + wait + + # Create a file that verifies that the Run concluded properly + touch "${RUN_DIR}done.out" + echo "Finished FL Processes" + +done diff --git a/research/cifar10/adaptive_pfl/fedprox/run_hp_sweep.sh b/research/cifar10/adaptive_pfl/fedprox/run_hp_sweep.sh new file mode 100755 index 000000000..4df2adfc3 --- /dev/null +++ b/research/cifar10/adaptive_pfl/fedprox/run_hp_sweep.sh @@ -0,0 +1,76 @@ +#!/bin/bash + +############################################### +# Usage: +# +# ./research/cifar10/adaptive_pfl/fedprox/run_hp_sweep.sh \ +# path_to_config.yaml \ +# path_to_folder_for_artifacts/ \ +# path_to_folder_for_dataset/ \ +# path_to_desired_venv/ +# +# Example: +# ./research/cifar10/adaptive_pfl/fedprox/run_hp_sweep.sh \ +# research/cifar10/adaptive_pfl/fedprox/config.yaml \ +# research/cifar10/adaptive_pfl/fedprox/ \ +# /datasets/cifar10 \ +# /h/demerson/vector_repositories/fl4health_env/ +# +# Notes: +# 1) The bash command above should be run from the top level directory of the repository. 
+############################################### + +SERVER_CONFIG_PATH=$1 +ARTIFACT_DIR=$2 +DATASET_DIR=$3 +VENV_PATH=$4 + +LR_VALUES=( 0.0001 0.001 0.01 0.1 ) +# Note: These values must correspond to values for the preprocessed CIFAR datasets +BETA_VALUES=( 0.1 0.5 5.0 ) +LAM_VALUES=( 0.001 1.0 ) +ADAPT_BOOL=( "TRUE" "FALSE" ) + +SERVER_PORT=8100 + +# Create sweep folder +SWEEP_DIRECTORY="${ARTIFACT_DIR}hp_sweep_results" +echo "Creating sweep folder at ${SWEEP_DIRECTORY}" +mkdir ${SWEEP_DIRECTORY} + +for BETA_VALUE in "${BETA_VALUES[@]}"; do + echo "Creating folder for beta ${BETA_VALUE}" + mkdir "${SWEEP_DIRECTORY}/beta_${BETA_VALUE}" + for LR_VALUE in "${LR_VALUES[@]}"; + do + for LAM_VALUE in "${LAM_VALUES[@]}"; + do + for ADAPT in "${ADAPT_BOOL[@]}"; + do + EXPERIMENT_NAME="lr_${LR_VALUE}_beta_${BETA_VALUE}_lam_${LAM_VALUE}" + if [[ ${ADAPT} == "TRUE" ]]; then + EXPERIMENT_NAME="${EXPERIMENT_NAME}_adapt" + fi + echo "Beginning Experiment ${EXPERIMENT_NAME}" + EXPERIMENT_DIRECTORY="${SWEEP_DIRECTORY}/beta_${BETA_VALUE}/${EXPERIMENT_NAME}/" + echo "Creating experiment folder ${EXPERIMENT_DIRECTORY}" + mkdir "${EXPERIMENT_DIRECTORY}" + SERVER_ADDRESS="0.0.0.0:${SERVER_PORT}" + echo "Server Address: ${SERVER_ADDRESS}" + SBATCH_COMMAND="research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm \ + ${SERVER_CONFIG_PATH} \ + ${EXPERIMENT_DIRECTORY} \ + ${DATASET_DIR} \ + ${VENV_PATH} \ + ${LR_VALUE} \ + ${LAM_VALUE} \ + ${SERVER_ADDRESS} \ + ${BETA_VALUE} \ + ${ADAPT}" + sbatch ${SBATCH_COMMAND} + ((SERVER_PORT=SERVER_PORT+1)) + done + done + done +done +echo Experiments Launched diff --git a/research/cifar10/adaptive_pfl/fedprox/server.py b/research/cifar10/adaptive_pfl/fedprox/server.py new file mode 100644 index 000000000..3206d4ef2 --- /dev/null +++ b/research/cifar10/adaptive_pfl/fedprox/server.py @@ -0,0 +1,159 @@ +import argparse +import os +from functools import partial +from logging import INFO +from typing import Any, Dict + +import flwr as fl +from 
flwr.common.logger import log +from flwr.common.typing import Config +from flwr.server.client_manager import SimpleClientManager + +from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.server.adaptive_constraint_servers.fedprox_server import FedProxServer +from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint +from fl4health.utils.config import load_config +from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn +from fl4health.utils.parameter_extraction import get_all_model_parameters +from fl4health.utils.random import set_all_random_seeds +from research.cifar10.model import ConvNet + + +def fit_config( + batch_size: int, + local_epochs: int, + n_server_rounds: int, + n_clients: int, + current_server_round: int, +) -> Config: + return { + "batch_size": batch_size, + "local_epochs": local_epochs, + "n_server_rounds": n_server_rounds, + "n_clients": n_clients, + "current_server_round": current_server_round, + } + + +def main( + config: Dict[str, Any], + server_address: str, + checkpoint_stub: str, + run_name: str, + lam: float, + adapt_loss_weight: bool, +) -> None: + # This function will be used to produce a config that is sent to each client to initialize their own environment + fit_config_fn = partial( + fit_config, + config["batch_size"], + config["local_epochs"], + config["n_server_rounds"], + config["n_clients"], + ) + checkpoint_dir = os.path.join(checkpoint_stub, run_name) + best_checkpoint_name = "server_best_model.pkl" + last_checkpoint_name = "server_last_model.pkl" + checkpointer = [ + BestLossTorchCheckpointer(checkpoint_dir, best_checkpoint_name), + LatestTorchCheckpointer(checkpoint_dir, last_checkpoint_name), + ] + + client_manager = SimpleClientManager() + # Initializing the model on the server side + model = ConvNet(in_channels=3, use_bn=False) + # Server performs simple FedAveraging as its 
server-side optimization strategy + strategy = FedAvgWithAdaptiveConstraint( + min_fit_clients=config["n_clients"], + min_evaluate_clients=config["n_clients"], + # Server waits for min_available_clients before starting FL rounds + min_available_clients=config["n_clients"], + on_fit_config_fn=fit_config_fn, + # We use the same fit config function, as nothing changes for eval + on_evaluate_config_fn=fit_config_fn, + fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, + initial_parameters=get_all_model_parameters(model), + initial_loss_weight=lam, + adapt_loss_weight=adapt_loss_weight, + ) + + server = FedProxServer( + client_manager=client_manager, + model=model, + wandb_reporter=None, + strategy=strategy, + checkpointer=checkpointer, + ) + + fl.server.start_server( + server=server, + server_address=server_address, + config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]), + ) + + log(INFO, "Training Complete") + assert isinstance(checkpointer[0], BestLossTorchCheckpointer) + log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer[0].best_score}") + + # Shutdown the server gracefully + server.shutdown() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FL Server Main") + parser.add_argument( + "--artifact_dir", + action="store", + type=str, + help="Path to save server artifacts such as logs and model checkpoints", + required=True, + ) + parser.add_argument( + "--run_name", + action="store", + help="Name of the run, model checkpoints will be saved under a subfolder with this name", + required=True, + ) + parser.add_argument( + "--config_path", + action="store", + type=str, + help="Path to configuration file.", + default="config.yaml", + ) + parser.add_argument( + "--server_address", + action="store", + type=str, + help="Server Address to be used to communicate with the clients", + default="0.0.0.0:8080", + ) + parser.add_argument( + 
"--seed", + action="store", + type=int, + help="Seed for the random number generators across python, torch, and numpy", + required=False, + ) + parser.add_argument( + "--lam", action="store", type=float, help="FedProx loss weight for local model training", default=0.01 + ) + parser.add_argument( + "--use_adaptation", + action="store_true", + help="Whether or not the loss weight for model drift is adapted or remains fixed.", + default=False, + ) + args = parser.parse_args() + + config = load_config(args.config_path) + log(INFO, f"Server Address: {args.server_address}") + log(INFO, f"Lambda: {args.lam}") + if args.use_adaptation: + log(INFO, "Adapting the loss weight for model drift via model loss") + + # Set the random seed for reproducibility + set_all_random_seeds(args.seed) + + main(config, args.server_address, args.artifact_dir, args.run_name, args.lam, args.use_adaptation) diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/client.py b/research/cifar10/adaptive_pfl/fenda_ditto/client.py new file mode 100644 index 000000000..4fa23e727 --- /dev/null +++ b/research/cifar10/adaptive_pfl/fenda_ditto/client.py @@ -0,0 +1,179 @@ +import argparse +import os +from logging import INFO +from pathlib import Path +from typing import Dict, Optional, Sequence, Tuple + +import flwr as fl +import torch +from flwr.common.logger import log +from flwr.common.typing import Config +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.clients.fenda_ditto_client import FendaDittoClient +from fl4health.model_bases.fenda_base import FendaModel +from fl4health.model_bases.sequential_split_models import SequentiallySplitModel +from fl4health.utils.config import narrow_dict_type +from fl4health.utils.losses import LossMeterType +from 
fl4health.utils.metrics import F1, Accuracy, Metric +from fl4health.utils.random import set_all_random_seeds +from research.cifar10.model import ConvNetFendaDittoGlobalModel, ConvNetFendaModel +from research.cifar10.preprocess import get_preprocessed_data + + +class CifarFendaDittoClient(FendaDittoClient): + def __init__( + self, + data_path: Path, + metrics: Sequence[Metric], + device: torch.device, + client_number: int, + learning_rate: float, + heterogeneity_level: float, + loss_meter_type: LossMeterType = LossMeterType.AVERAGE, + checkpointer: Optional[ClientCheckpointModule] = None, + freeze_global_feature_extractor: bool = False, + ) -> None: + super().__init__( + data_path=data_path, + metrics=metrics, + device=device, + loss_meter_type=loss_meter_type, + checkpointer=checkpointer, + freeze_global_feature_extractor=freeze_global_feature_extractor, + ) + self.client_number = client_number + self.heterogeneity_level = heterogeneity_level + self.learning_rate: float = learning_rate + + log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") + + def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + batch_size = narrow_dict_type(config, "batch_size", int) + train_loader, val_loader, _ = get_preprocessed_data( + self.data_path, self.client_number, batch_size, self.heterogeneity_level + ) + return train_loader, val_loader + + def get_criterion(self, config: Config) -> _Loss: + return torch.nn.CrossEntropyLoss() + + def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=self.learning_rate) + local_optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) + return {"global": global_optimizer, "local": local_optimizer} + + def get_model(self, config: Config) -> FendaModel: + return ConvNetFendaModel(in_channels=3, use_bn=False).to(self.device) + + def get_global_model(self, config: Config) -> 
SequentiallySplitModel: + return ConvNetFendaDittoGlobalModel(in_channels=3, use_bn=False).to(self.device) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FL Client Main") + parser.add_argument( + "--artifact_dir", + action="store", + type=str, + help="Path to save client artifacts such as logs and model checkpoints", + required=True, + ) + parser.add_argument( + "--dataset_dir", + action="store", + type=str, + help="Path to the preprocessed Cifar 10 Dataset", + required=True, + ) + parser.add_argument( + "--run_name", + action="store", + help="Name of the run, model checkpoints will be saved under a subfolder with this name", + required=True, + ) + parser.add_argument( + "--server_address", + action="store", + type=str, + help="Server Address for the clients to communicate with the server through", + default="0.0.0.0:8080", + ) + parser.add_argument( + "--client_number", + action="store", + type=int, + help="Number of the client for dataset loading", + required=True, + ) + parser.add_argument( + "--learning_rate", action="store", type=float, help="Learning rate for local optimization", default=0.1 + ) + parser.add_argument( + "--seed", + action="store", + type=int, + help="Seed for the random number generators across python, torch, and numpy", + required=False, + ) + parser.add_argument( + "--beta", + action="store", + type=float, + help="Heterogeneity level for the dataset", + required=True, + ) + parser.add_argument( + "--freeze_global_extractor", + action="store_true", + help="Whether or not to freeze the global feature extractor of the FENDA model or not.", + default=False, + ) + args = parser.parse_args() + + DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {DEVICE}") + log(INFO, f"Server Address: {args.server_address}") + log(INFO, f"Learning Rate: {args.learning_rate}") + log(INFO, f"Beta: {args.beta}") + if args.freeze_global_extractor: + log(INFO, "Freezing the global 
feature extractor of the FENDA model") + + # Set the random seed for reproducibility + set_all_random_seeds(args.seed) + + # Adding extensive checkpointing for the client + checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) + pre_aggregation_best_checkpoint_name = f"pre_aggregation_client_{args.client_number}_best_model.pkl" + pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" + post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" + post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" + checkpointer = ClientCheckpointModule( + pre_aggregation=[ + BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), + LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), + ], + post_aggregation=[ + BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), + LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), + ], + ) + + data_path = Path(args.dataset_dir) + client = CifarFendaDittoClient( + data_path=data_path, + metrics=[Accuracy("accuracy"), F1("F1_Score", average="macro")], + device=DEVICE, + client_number=args.client_number, + learning_rate=args.learning_rate, + heterogeneity_level=args.beta, + checkpointer=checkpointer, + freeze_global_feature_extractor=args.freeze_global_extractor, + ) + + fl.client.start_client(server_address=args.server_address, client=client.to_client()) + client.shutdown() diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/config.yaml b/research/cifar10/adaptive_pfl/fenda_ditto/config.yaml new file mode 100644 index 000000000..323b2a693 --- /dev/null +++ b/research/cifar10/adaptive_pfl/fenda_ditto/config.yaml @@ -0,0 +1,7 @@ +# Parameters that describe server +n_server_rounds: 50 # The number of rounds to run FL + +# Parameters that describe clients +n_clients: 7 # The number of clients in 
the FL experiment +local_epochs: 1 # The number of epochs to complete for client +batch_size: 32 # The batch size for client training diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm new file mode 100644 index 000000000..68c6589ae --- /dev/null +++ b/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm @@ -0,0 +1,198 @@ +#!/bin/bash + +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --mem=32G +#SBATCH --partition=a40 +#SBATCH --qos=m2 +#SBATCH --job-name=fl_five_fold_exp +#SBATCH --output=%j_%x.out +#SBATCH --error=%j_%x.err +#SBATCH --time=4:00:00 + +############################################### +# Usage: +# +# sbatch research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm \ +# path_to_config.yaml \ +# path_to_folder_for_artifacts/ \ +# path_to_folder_for_dataset/ \ +# path_to_desired_venv/ \ +# client_side_learning_rate_value \ +# lambda value \ +# server_address\ +# client_beta \ +# adapt \ +# freeze +# +# Example: +# sbatch research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm \ +# research/cifar10/adaptive_pfl/fenda_ditto/config.yaml \ +# research/cifar10/adaptive_pfl/fenda_ditto/hp_results/ \ +# /datasets/cifar10 \ +# /h/demerson/vector_repositories/fl4health_env/ \ +# 0.0001 \ +# 0.01 \ +# 0.0.0.0:8080 \ +# 0.1 \ +# "TRUE" \ +# "TRUE" +# +# Notes: +# 1) The sbatch command above should be run from the top level directory of the repository. +# 2) This example runs fenda_ditto. As such the data paths and python launch commands are hardcoded. If you want to change +# the example you run, you need to explicitly modify the code below. +# 3) The logging directories need to ALREADY EXIST. The script does not create them. 
+###############################################
+
+# Note:
+# ntasks: Total number of processes to use across world
+# ntasks-per-node: How many processes each node should create
+
+# Set NCCL options
+# export NCCL_DEBUG=INFO
+# NCCL backend to communicate between GPU workers is not provided in vector's cluster.
+# Disable this option in slurm.
+export NCCL_IB_DISABLE=1
+
+if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \
+    [[ "${SLURM_JOB_PARTITION}" == "rtx6000" ]]; then
+    echo export NCCL_SOCKET_IFNAME=bond0 on "${SLURM_JOB_PARTITION}"
+    export NCCL_SOCKET_IFNAME=bond0
+fi
+
+# Process Inputs (positional arguments, in the order given in the usage header)
+
+SERVER_CONFIG_PATH=$1
+ARTIFACT_DIR=$2
+DATASET_DIR=$3
+VENV_PATH=$4
+CLIENT_LR=$5
+LAM_VALUE=$6
+SERVER_ADDRESS=$7
+CLIENT_BETA=$8
+ADAPT=$9
+FREEZE=${10}  # braces are required for positional parameters beyond $9
+
+# Create the artifact directory; -p tolerates one that already exists (sweep launches, re-queued jobs)
+mkdir -p "${ARTIFACT_DIR}"
+
+RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" )
+SEEDS=(2021 2022 2023 2024 2025)
+
+echo "Python Venv Path: ${VENV_PATH}"
+
+echo "World size: ${SLURM_NTASKS}"
+echo "Number of nodes: ${SLURM_NNODES}"
+NUM_GPUs=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
+echo "GPUs per node: ${NUM_GPUs}"
+
+# Source the environment (VENV_PATH is expected to end with a trailing slash)
+source "${VENV_PATH}bin/activate"
+echo "Active Environment:"
+which python
+
+for ((i=0; i<${#RUN_NAMES[@]}; i++));
+do
+    RUN_NAME="${RUN_NAMES[i]}"
+    SEED="${SEEDS[i]}"
+    # create the run directory (ARTIFACT_DIR is expected to end with a trailing slash)
+    RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/"
+    echo "Starting Run and logging artifcats at ${RUN_DIR}"
+    if [ -d "${RUN_DIR}" ]
+    then
+        # Directory already exists, we check if the done.out file exists
+        if [ -f "${RUN_DIR}done.out" ]
+        then
+            # Done file already exists so we skip this run
+            echo "Run already completed. Skipping Run."
+            continue
+        else
+            # Done file doesn't exist (assume pre-emption happened)
+            # Delete the partially finished contents and start over
+            echo "Run did not finished correctly. Re-running."
+            rm -r "${RUN_DIR}"
+            mkdir "${RUN_DIR}"
+        fi
+    else
+        # Directory doesn't exist yet, so we create it.
+ echo "Run directory does not exist. Creating it." + mkdir "${RUN_DIR}" + fi + + SERVER_OUTPUT_FILE="${RUN_DIR}server.out" + + # Start the server, divert the outputs to a server file + + echo "Server logging at: ${SERVER_OUTPUT_FILE}" + echo "Launching Server" + + if [[ ${ADAPT} == "TRUE" ]]; then + nohup python -m research.cifar10.adaptive_pfl.fenda_ditto.server \ + --config_path ${SERVER_CONFIG_PATH} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --lam ${LAM_VALUE} \ + --use_adaptation \ + > ${SERVER_OUTPUT_FILE} 2>&1 & + else + nohup python -m research.cifar10.adaptive_pfl.fenda_ditto.server \ + --config_path ${SERVER_CONFIG_PATH} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --lam ${LAM_VALUE} \ + > ${SERVER_OUTPUT_FILE} 2>&1 & + fi + + # Sleep for 20 seconds to allow the server to come up. + sleep 20 + + # Start n number of clients and divert the outputs to their own files + n_clients=7 + for (( c=0; c<${n_clients}; c++ )) + do + CLIENT_NAME="client_${c}" + echo "Launching ${CLIENT_NAME}" + + CLIENT_LOG_PATH="${RUN_DIR}${CLIENT_NAME}.out" + echo "${CLIENT_NAME} logging at: ${CLIENT_LOG_PATH}" + + if [[ ${ADAPT} == "TRUE" ]]; then + nohup python -m research.cifar10.adaptive_pfl.fenda_ditto.client \ + --artifact_dir ${ARTIFACT_DIR} \ + --dataset_dir ${DATASET_DIR} \ + --run_name ${RUN_NAME} \ + --client_number ${c} \ + --learning_rate ${CLIENT_LR} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --beta ${CLIENT_BETA} \ + --freeze_global_extractor \ + > ${CLIENT_LOG_PATH} 2>&1 & + else + nohup python -m research.cifar10.adaptive_pfl.fenda_ditto.client \ + --artifact_dir ${ARTIFACT_DIR} \ + --dataset_dir ${DATASET_DIR} \ + --run_name ${RUN_NAME} \ + --client_number ${c} \ + --learning_rate ${CLIENT_LR} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --beta ${CLIENT_BETA} \ + > ${CLIENT_LOG_PATH} 2>&1 & + fi + + done + + echo "FL Processes Running" + + wait + + # Create a file that verifies that the Run 
concluded properly + touch "${RUN_DIR}done.out" + echo "Finished FL Processes" + +done diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/run_hp_sweep.sh b/research/cifar10/adaptive_pfl/fenda_ditto/run_hp_sweep.sh new file mode 100755 index 000000000..194f8330e --- /dev/null +++ b/research/cifar10/adaptive_pfl/fenda_ditto/run_hp_sweep.sh @@ -0,0 +1,84 @@ +#!/bin/bash + +############################################### +# Usage: +# +# ./research/cifar10/adaptive_pfl/fenda_ditto/run_hp_sweep.sh \ +# path_to_config.yaml \ +# path_to_folder_for_artifacts/ \ +# path_to_folder_for_dataset/ \ +# path_to_desired_venv/ +# +# Example: +# ./research/cifar10/adaptive_pfl/fenda_ditto/run_hp_sweep.sh \ +# research/cifar10/adaptive_pfl/fenda_ditto/config.yaml \ +# research/cifar10/adaptive_pfl/fenda_ditto \ +# /datasets/cifar10 \ +# /h/demerson/vector_repositories/fl4health_env/ +# +# Notes: +# 1) The bash command above should be run from the top level directory of the repository. +############################################### + +SERVER_CONFIG_PATH=$1 +ARTIFACT_DIR=$2 +DATASET_DIR=$3 +VENV_PATH=$4 + +LR_VALUES=( 0.0001 0.001 0.01 0.1 ) +# Note: These values must correspond to values for the preprocessed CIFAR datasets +BETA_VALUES=( 0.1 0.5 5.0 ) +LAM_VALUES=( 0.001 1.0 ) +ADAPT_BOOL=( "TRUE" "FALSE" ) +FREEZE_FEATURE_EXTRACTOR=( "TRUE" "FALSE" ) + +SERVER_PORT=8100 + +# Create sweep folder +SWEEP_DIRECTORY="${ARTIFACT_DIR}hp_sweep_results" +echo "Creating sweep folder at ${SWEEP_DIRECTORY}" +mkdir ${SWEEP_DIRECTORY} + +for BETA_VALUE in "${BETA_VALUES[@]}"; do + echo "Creating folder for beta ${BETA_VALUE}" + mkdir "${SWEEP_DIRECTORY}/beta_${BETA_VALUE}" + for LR_VALUE in "${LR_VALUES[@]}"; + do + for LAM_VALUE in "${LAM_VALUES[@]}"; + do + for ADAPT in "${ADAPT_BOOL[@]}"; + do + for FREEZE in "${FREEZE_FEATURE_EXTRACTOR[@]}"; + do + EXPERIMENT_NAME="lr_${LR_VALUE}_beta_${BETA_VALUE}_lam_${LAM_VALUE}" + if [[ ${ADAPT} == "TRUE" ]]; then + 
EXPERIMENT_NAME="${EXPERIMENT_NAME}_adapt" + fi + if [[ ${FREEZE} == "TRUE" ]]; then + EXPERIMENT_NAME="${EXPERIMENT_NAME}_freeze" + fi + echo "Beginning Experiment ${EXPERIMENT_NAME}" + EXPERIMENT_DIRECTORY="${SWEEP_DIRECTORY}/beta_${BETA_VALUE}/${EXPERIMENT_NAME}/" + echo "Creating experiment folder ${EXPERIMENT_DIRECTORY}" + mkdir "${EXPERIMENT_DIRECTORY}" + SERVER_ADDRESS="0.0.0.0:${SERVER_PORT}" + echo "Server Address: ${SERVER_ADDRESS}" + SBATCH_COMMAND="research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm \ + ${SERVER_CONFIG_PATH} \ + ${EXPERIMENT_DIRECTORY} \ + ${DATASET_DIR} \ + ${VENV_PATH} \ + ${LR_VALUE} \ + ${LAM_VALUE} \ + ${SERVER_ADDRESS} \ + ${BETA_VALUE} \ + ${ADAPT} \ + ${FREEZE}" + sbatch ${SBATCH_COMMAND} + ((SERVER_PORT=SERVER_PORT+1)) + done + done + done + done +done +echo Experiments Launched diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/server.py b/research/cifar10/adaptive_pfl/fenda_ditto/server.py new file mode 100644 index 000000000..767583c80 --- /dev/null +++ b/research/cifar10/adaptive_pfl/fenda_ditto/server.py @@ -0,0 +1,147 @@ +import argparse +from functools import partial +from logging import INFO +from typing import Any, Dict, Optional + +import flwr as fl +from flwr.common.logger import log +from flwr.common.typing import Config +from flwr.server.client_manager import ClientManager, SimpleClientManager +from flwr.server.strategy import Strategy + +from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint +from fl4health.utils.config import load_config +from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn +from fl4health.utils.parameter_extraction import get_all_model_parameters +from fl4health.utils.random import set_all_random_seeds +from research.cifar10.model import ConvNet +from research.cifar10.personal_server import PersonalServer + + +class PersonalFendaDittoServer(PersonalServer): + """ + The PersonalServer 
class is used for FL approaches that only have a sense of a PERSONAL model that is checkpointed + and valid only on the client size of the FL training framework. FL approaches like APFL and FENDA fall under this + category. Each client will have its own model that is specific to its own training. Personal models may have + shared components but the full model is specific to each client. This is distinct from the + FlServerWithCheckpointing class which has a sense of a GLOBAL model checkpointed on the server-side that is + shared by all clients. + """ + + def __init__( + self, + client_manager: ClientManager, + strategy: Optional[Strategy] = None, + ) -> None: + assert isinstance( + strategy, FedAvgWithAdaptiveConstraint + ), "Strategy must be of base type FedAvgWithAdaptiveConstraint" + # Personal approaches don't train a "server" model. Rather, each client trains a client specific model with + # some globally shared weights. So we don't checkpoint a global model + super().__init__(client_manager, strategy) + + +def fit_config( + batch_size: int, + local_epochs: int, + n_server_rounds: int, + n_clients: int, + current_server_round: int, +) -> Config: + return { + "batch_size": batch_size, + "local_epochs": local_epochs, + "n_server_rounds": n_server_rounds, + "n_clients": n_clients, + "current_server_round": current_server_round, + } + + +def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_weight: bool) -> None: + # This function will be used to produce a config that is sent to each client to initialize their own environment + fit_config_fn = partial( + fit_config, + config["batch_size"], + config["local_epochs"], + config["n_server_rounds"], + config["n_clients"], + ) + + client_manager = SimpleClientManager() + # Initializing the model on the server side + model = ConvNet(in_channels=3, use_bn=False) + # Server performs simple FedAveraging as its server-side optimization strategy + strategy = FedAvgWithAdaptiveConstraint( + 
min_fit_clients=config["n_clients"], + min_evaluate_clients=config["n_clients"], + # Server waits for min_available_clients before starting FL rounds + min_available_clients=config["n_clients"], + on_fit_config_fn=fit_config_fn, + # We use the same fit config function, as nothing changes for eval + on_evaluate_config_fn=fit_config_fn, + fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, + initial_parameters=get_all_model_parameters(model), + initial_loss_weight=lam, + adapt_loss_weight=adapt_loss_weight, + ) + + server = PersonalFendaDittoServer(client_manager=client_manager, strategy=strategy) + + fl.server.start_server( + server=server, + server_address=server_address, + config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]), + ) + + log(INFO, "Training Complete") + log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{server.best_aggregated_loss}") + + # Shutdown the server gracefully + server.shutdown() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FL Server Main") + parser.add_argument( + "--config_path", + action="store", + type=str, + help="Path to configuration file.", + default="config.yaml", + ) + parser.add_argument( + "--server_address", + action="store", + type=str, + help="Server Address to be used to communicate with the clients", + default="0.0.0.0:8080", + ) + parser.add_argument( + "--seed", + action="store", + type=int, + help="Seed for the random number generators across python, torch, and numpy", + required=False, + ) + parser.add_argument( + "--lam", action="store", type=float, help="FENDA Ditto loss weight for local model training", default=0.01 + ) + parser.add_argument( + "--use_adaptation", + action="store_true", + help="Whether or not the loss weight for model drift is adapted or remains fixed.", + default=False, + ) + args = parser.parse_args() + + config = load_config(args.config_path) + log(INFO, f"Server 
Address: {args.server_address}") + log(INFO, f"Lambda: {args.lam}") + if args.use_adaptation: + log(INFO, "Adapting the loss weight for model drift via global model loss") + + # Set the random seed for reproducibility + set_all_random_seeds(args.seed) + + main(config, args.server_address, args.lam, args.use_adaptation) diff --git a/research/cifar10/adaptive_pfl/mrmtl/client.py b/research/cifar10/adaptive_pfl/mrmtl/client.py new file mode 100644 index 000000000..84ff0c9ec --- /dev/null +++ b/research/cifar10/adaptive_pfl/mrmtl/client.py @@ -0,0 +1,162 @@ +import argparse +import os +from logging import INFO +from pathlib import Path +from typing import Optional, Sequence, Tuple + +import flwr as fl +import torch +import torch.nn as nn +from flwr.common.logger import log +from flwr.common.typing import Config +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.clients.mr_mtl_client import MrMtlClient +from fl4health.utils.config import narrow_dict_type +from fl4health.utils.losses import LossMeterType +from fl4health.utils.metrics import F1, Accuracy, Metric +from fl4health.utils.random import set_all_random_seeds +from research.cifar10.model import ConvNet +from research.cifar10.preprocess import get_preprocessed_data + + +class CifarMrMtlClient(MrMtlClient): + def __init__( + self, + data_path: Path, + metrics: Sequence[Metric], + device: torch.device, + client_number: int, + learning_rate: float, + heterogeneity_level: float, + loss_meter_type: LossMeterType = LossMeterType.AVERAGE, + checkpointer: Optional[ClientCheckpointModule] = None, + ) -> None: + super().__init__( + data_path=data_path, + metrics=metrics, + device=device, + loss_meter_type=loss_meter_type, + checkpointer=checkpointer, + ) + 
self.client_number = client_number + self.heterogeneity_level = heterogeneity_level + self.learning_rate: float = learning_rate + + log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") + + def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + batch_size = narrow_dict_type(config, "batch_size", int) + train_loader, val_loader, _ = get_preprocessed_data( + self.data_path, self.client_number, batch_size, self.heterogeneity_level + ) + return train_loader, val_loader + + def get_criterion(self, config: Config) -> _Loss: + return torch.nn.CrossEntropyLoss() + + def get_optimizer(self, config: Config) -> Optimizer: + return torch.optim.SGD(self.model.parameters(), lr=self.learning_rate) + + def get_model(self, config: Config) -> nn.Module: + return ConvNet(in_channels=3, use_bn=False).to(self.device) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FL Client Main") + parser.add_argument( + "--artifact_dir", + action="store", + type=str, + help="Path to save client artifacts such as logs and model checkpoints", + required=True, + ) + parser.add_argument( + "--dataset_dir", + action="store", + type=str, + help="Path to the preprocessed Cifar 10 Dataset", + required=True, + ) + parser.add_argument( + "--run_name", + action="store", + help="Name of the run, model checkpoints will be saved under a subfolder with this name", + required=True, + ) + parser.add_argument( + "--server_address", + action="store", + type=str, + help="Server Address for the clients to communicate with the server through", + default="0.0.0.0:8080", + ) + parser.add_argument( + "--client_number", + action="store", + type=int, + help="Number of the client for dataset loading", + required=True, + ) + parser.add_argument( + "--learning_rate", action="store", type=float, help="Learning rate for local optimization", default=0.1 + ) + parser.add_argument( + "--seed", + action="store", + type=int, + help="Seed for the random 
number generators across python, torch, and numpy", + required=False, + ) + parser.add_argument( + "--beta", + action="store", + type=float, + help="Heterogeneity level for the dataset", + required=True, + ) + args = parser.parse_args() + + DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {DEVICE}") + log(INFO, f"Server Address: {args.server_address}") + log(INFO, f"Learning Rate: {args.learning_rate}") + log(INFO, f"Beta: {args.beta}") + + # Set the random seed for reproducibility + set_all_random_seeds(args.seed) + + # Adding extensive checkpointing for the client + checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) + pre_aggregation_best_checkpoint_name = f"pre_aggregation_client_{args.client_number}_best_model.pkl" + pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" + post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" + post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" + checkpointer = ClientCheckpointModule( + pre_aggregation=[ + BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), + LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), + ], + post_aggregation=[ + BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), + LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), + ], + ) + + data_path = Path(args.dataset_dir) + client = CifarMrMtlClient( + data_path=data_path, + metrics=[Accuracy("accuracy"), F1("F1_Score", average="macro")], + device=DEVICE, + client_number=args.client_number, + learning_rate=args.learning_rate, + heterogeneity_level=args.beta, + checkpointer=checkpointer, + ) + + fl.client.start_client(server_address=args.server_address, client=client.to_client()) + client.shutdown() diff --git 
a/research/cifar10/adaptive_pfl/mrmtl/config.yaml b/research/cifar10/adaptive_pfl/mrmtl/config.yaml new file mode 100644 index 000000000..323b2a693 --- /dev/null +++ b/research/cifar10/adaptive_pfl/mrmtl/config.yaml @@ -0,0 +1,7 @@ +# Parameters that describe server +n_server_rounds: 50 # The number of rounds to run FL + +# Parameters that describe clients +n_clients: 7 # The number of clients in the FL experiment +local_epochs: 1 # The number of epochs to complete for client +batch_size: 32 # The batch size for client training diff --git a/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm new file mode 100644 index 000000000..77db6de76 --- /dev/null +++ b/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm @@ -0,0 +1,179 @@ +#!/bin/bash + +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --mem=32G +#SBATCH --partition=a40 +#SBATCH --qos=m2 +#SBATCH --job-name=fl_five_fold_exp +#SBATCH --output=%j_%x.out +#SBATCH --error=%j_%x.err +#SBATCH --time=4:00:00 + +############################################### +# Usage: +# +# sbatch research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm \ +# path_to_config.yaml \ +# path_to_folder_for_artifacts/ \ +# path_to_folder_for_dataset/ \ +# path_to_desired_venv/ \ +# client_side_learning_rate_value \ +# lambda value \ +# server_address \ +# client_beta \ +# adapt +# +# Example: +# sbatch research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm \ +# research/cifar10/adaptive_pfl/mrmtl/config.yaml \ +# research/cifar10/adaptive_pfl/mrmtl/hp_results/ \ +# /datasets/cifar10 \ +# /h/demerson/vector_repositories/fl4health_env/ \ +# 0.0001 \ +# 0.01 \ +# 0.0.0.0:8080\ +# 0.1 \ +# "TRUE" +# +# Notes: +# 1) The sbatch command above should be run from the top level directory of the repository. +# 2) This example runs mrmtl. 
As such the data paths and python launch commands are hardcoded. If you want to change
+# the example you run, you need to explicitly modify the code below.
+# 3) The artifact directory and the per-run subdirectories are created by this script when missing.
+###############################################
+
+# Note:
+# ntasks: Total number of processes to use across world
+# ntasks-per-node: How many processes each node should create
+
+# Set NCCL options
+# export NCCL_DEBUG=INFO
+# NCCL backend to communicate between GPU workers is not provided in vector's cluster.
+# Disable this option in slurm.
+export NCCL_IB_DISABLE=1
+
+if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \
+    [[ "${SLURM_JOB_PARTITION}" == "rtx6000" ]]; then
+    echo export NCCL_SOCKET_IFNAME=bond0 on "${SLURM_JOB_PARTITION}"
+    export NCCL_SOCKET_IFNAME=bond0
+fi
+
+# Process Inputs (positional arguments, in the order given in the usage header)
+
+SERVER_CONFIG_PATH=$1
+ARTIFACT_DIR=$2
+DATASET_DIR=$3
+VENV_PATH=$4
+CLIENT_LR=$5
+LAM_VALUE=$6
+SERVER_ADDRESS=$7
+CLIENT_BETA=$8
+ADAPT=$9
+
+# Create the artifact directory; -p tolerates one that already exists (sweep launches, re-queued jobs)
+mkdir -p "${ARTIFACT_DIR}"
+
+RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" )
+SEEDS=(2021 2022 2023 2024 2025)
+
+echo "Python Venv Path: ${VENV_PATH}"
+
+echo "World size: ${SLURM_NTASKS}"
+echo "Number of nodes: ${SLURM_NNODES}"
+NUM_GPUs=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
+echo "GPUs per node: ${NUM_GPUs}"
+
+# Source the environment (VENV_PATH is expected to end with a trailing slash)
+source "${VENV_PATH}bin/activate"
+echo "Active Environment:"
+which python
+
+for ((i=0; i<${#RUN_NAMES[@]}; i++));
+do
+    RUN_NAME="${RUN_NAMES[i]}"
+    SEED="${SEEDS[i]}"
+    # create the run directory (ARTIFACT_DIR is expected to end with a trailing slash)
+    RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/"
+    echo "Starting Run and logging artifcats at ${RUN_DIR}"
+    if [ -d "${RUN_DIR}" ]
+    then
+        # Directory already exists, we check if the done.out file exists
+        if [ -f "${RUN_DIR}done.out" ]
+        then
+            # Done file already exists so we skip this run
+            echo "Run already completed. Skipping Run."
+ continue + else + # Done file doesn't exists (assume pre-emption happened) + # Delete the partially finished contents and start over + echo "Run did not finished correctly. Re-running." + rm -r "${RUN_DIR}" + mkdir "${RUN_DIR}" + fi + else + # Directory doesn't exist yet, so we create it. + echo "Run directory does not exist. Creating it." + mkdir "${RUN_DIR}" + fi + + SERVER_OUTPUT_FILE="${RUN_DIR}server.out" + + # Start the server, divert the outputs to a server file + + echo "Server logging at: ${SERVER_OUTPUT_FILE}" + echo "Launching Server" + + if [[ ${ADAPT} == "TRUE" ]]; then + nohup python -m research.cifar10.adaptive_pfl.mrmtl.server \ + --config_path ${SERVER_CONFIG_PATH} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --lam ${LAM_VALUE} \ + --use_adaptation \ + > ${SERVER_OUTPUT_FILE} 2>&1 & + else + nohup python -m research.cifar10.adaptive_pfl.mrmtl.server \ + --config_path ${SERVER_CONFIG_PATH} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --lam ${LAM_VALUE} \ + > ${SERVER_OUTPUT_FILE} 2>&1 & + fi + + # Sleep for 20 seconds to allow the server to come up. 
+    sleep 20
+
+    # Start n number of clients and divert the outputs to their own files
+    n_clients=7
+    for (( c=0; c<${n_clients}; c++ ))
+    do
+        CLIENT_NAME="client_${c}"
+        echo "Launching ${CLIENT_NAME}"
+
+        CLIENT_LOG_PATH="${RUN_DIR}${CLIENT_NAME}.out"
+        echo "${CLIENT_NAME} logging at: ${CLIENT_LOG_PATH}"
+        nohup python -m research.cifar10.adaptive_pfl.mrmtl.client \
+            --artifact_dir ${ARTIFACT_DIR} \
+            --dataset_dir ${DATASET_DIR} \
+            --run_name ${RUN_NAME} \
+            --client_number ${c} \
+            --learning_rate ${CLIENT_LR} \
+            --server_address ${SERVER_ADDRESS} \
+            --seed ${SEED} \
+            --beta ${CLIENT_BETA} \
+            > ${CLIENT_LOG_PATH} 2>&1 &
+    done
+
+    echo "FL Processes Running"
+
+    wait
+
+    # Create a file that verifies that the Run concluded properly
+    touch "${RUN_DIR}done.out"
+    echo "Finished FL Processes"
+
+done
diff --git a/research/cifar10/adaptive_pfl/mrmtl/run_hp_sweep.sh b/research/cifar10/adaptive_pfl/mrmtl/run_hp_sweep.sh
new file mode 100755
index 000000000..eafd9150b
--- /dev/null
+++ b/research/cifar10/adaptive_pfl/mrmtl/run_hp_sweep.sh
@@ -0,0 +1,85 @@
+#!/bin/bash
+
+###############################################
+# Usage:
+#
+# ./research/cifar10/adaptive_pfl/mrmtl/run_hp_sweep.sh \
+#   path_to_config.yaml \
+#   path_to_folder_for_artifacts/ \
+#   path_to_folder_for_dataset/ \
+#   path_to_desired_venv/
+#
+# Example:
+# ./research/cifar10/adaptive_pfl/mrmtl/run_hp_sweep.sh \
+#   research/cifar10/adaptive_pfl/mrmtl/config.yaml \
+#   research/cifar10/adaptive_pfl/mrmtl/ \
+#   /datasets/cifar10 \
+#   /h/demerson/vector_repositories/fl4health_env/
+#
+# Notes:
+# 1) The bash command above should be run from the top level directory of the repository.
+# 2) One sbatch job is submitted per (beta, lr, lam, adapt) combination, each with its own server port.
+###############################################
+
+if [[ $# -lt 4 ]]; then
+  echo "Usage: $0 path_to_config.yaml artifact_dir/ dataset_dir/ venv_path/" >&2
+  exit 1
+fi
+
+SERVER_CONFIG_PATH=$1
+ARTIFACT_DIR=$2
+DATASET_DIR=$3
+VENV_PATH=$4
+
+LR_VALUES=( 0.0001 0.001 0.01 0.1 )
+# Note: These values must correspond to values for the preprocessed CIFAR datasets
+BETA_VALUES=( 0.1 0.5 5.0 )
+LAM_VALUES=( 0.001 1.0 )
+ADAPT_BOOL=( "TRUE" "FALSE" )
+
+SERVER_PORT=8100
+
+# Create sweep folder. A trailing slash on the artifact directory is optional: it is stripped before joining.
+SWEEP_DIRECTORY="${ARTIFACT_DIR%/}/hp_sweep_results"
+echo "Creating sweep folder at ${SWEEP_DIRECTORY}"
+mkdir -p "${SWEEP_DIRECTORY}"
+
+for BETA_VALUE in "${BETA_VALUES[@]}"; do
+  echo "Creating folder for beta ${BETA_VALUE}"
+  mkdir -p "${SWEEP_DIRECTORY}/beta_${BETA_VALUE}"
+  for LR_VALUE in "${LR_VALUES[@]}";
+  do
+    for LAM_VALUE in "${LAM_VALUES[@]}";
+    do
+      for ADAPT in "${ADAPT_BOOL[@]}";
+      do
+        EXPERIMENT_NAME="lr_${LR_VALUE}_beta_${BETA_VALUE}_lam_${LAM_VALUE}"
+        if [[ ${ADAPT} == "TRUE" ]]; then
+          EXPERIMENT_NAME="${EXPERIMENT_NAME}_adapt"
+        fi
+        echo "Beginning Experiment ${EXPERIMENT_NAME}"
+        EXPERIMENT_DIRECTORY="${SWEEP_DIRECTORY}/beta_${BETA_VALUE}/${EXPERIMENT_NAME}/"
+        echo "Creating experiment folder ${EXPERIMENT_DIRECTORY}"
+        mkdir -p "${EXPERIMENT_DIRECTORY}"
+        SERVER_ADDRESS="0.0.0.0:${SERVER_PORT}"
+        echo "Server Address: ${SERVER_ADDRESS}"
+        # Keep the arguments in an array so that values containing spaces reach sbatch intact
+        SBATCH_ARGS=(
+          research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm
+          "${SERVER_CONFIG_PATH}"
+          "${EXPERIMENT_DIRECTORY}"
+          "${DATASET_DIR}"
+          "${VENV_PATH}"
+          "${LR_VALUE}"
+          "${LAM_VALUE}"
+          "${SERVER_ADDRESS}"
+          "${BETA_VALUE}"
+          "${ADAPT}"
+        )
+        sbatch "${SBATCH_ARGS[@]}"
+        ((SERVER_PORT=SERVER_PORT+1))
+      done
+    done
+  done
+done
+echo Experiments Launched
diff --git a/research/cifar10/adaptive_pfl/mrmtl/server.py b/research/cifar10/adaptive_pfl/mrmtl/server.py
new file mode 100644
index 000000000..8ffca7c81
--- /dev/null
+++ b/research/cifar10/adaptive_pfl/mrmtl/server.py
@@ -0,0 +1,147 @@
+import argparse
+from functools import partial
+from logging import INFO
+from typing import Any, Dict, Optional
+
+import flwr as fl
+from
flwr.common.logger import log
+from flwr.common.typing import Config
+from flwr.server.client_manager import ClientManager, SimpleClientManager
+from flwr.server.strategy import Strategy
+
+from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint
+from fl4health.utils.config import load_config
+from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
+from fl4health.utils.parameter_extraction import get_all_model_parameters
+from fl4health.utils.random import set_all_random_seeds
+from research.cifar10.model import ConvNet
+from research.cifar10.personal_server import PersonalServer
+
+
+class PersonalMrMtlServer(PersonalServer):
+    """
+    PersonalServer variant used for the MR-MTL experiments. The only thing it adds on top of PersonalServer is a
+    check that the strategy is a FedAvgWithAdaptiveConstraint; everything else is delegated to the parent class.
+
+    As with the other personal approaches, each client trains a client specific model with some globally shared
+    weights, so this server does not checkpoint a global model.
+    """
+
+    def __init__(
+        self,
+        client_manager: ClientManager,
+        strategy: Optional[Strategy] = None,
+    ) -> None:
+        """
+        Args:
+            client_manager (ClientManager): Client manager forwarded to PersonalServer.
+            strategy (Optional[Strategy]): Strategy forwarded to PersonalServer. Although typed as optional, it
+                must be an instance of FedAvgWithAdaptiveConstraint.
+
+        Raises:
+            AssertionError: If strategy is not a FedAvgWithAdaptiveConstraint (including None).
+        """
+        assert isinstance(
+            strategy, FedAvgWithAdaptiveConstraint
+        ), "Strategy must be of base type FedAvgWithAdaptiveConstraint"
+        # Personal approaches don't train a "server" model. Rather, each client trains a client specific model with
+        # some globally shared weights. So we don't checkpoint a global model
+        super().__init__(client_manager, strategy)
+
+
+def fit_config(
+    batch_size: int,
+    local_epochs: int,
+    n_server_rounds: int,
+    n_clients: int,
+    current_server_round: int,
+) -> Config:
+    """
+    Build the configuration dictionary sent to the clients for a given server round.
+
+    All arguments are returned unchanged under keys of the same name. current_server_round is the last positional
+    argument so that the other values can be bound ahead of time with functools.partial.
+    """
+    return {
+        "batch_size": batch_size,
+        "local_epochs": local_epochs,
+        "n_server_rounds": n_server_rounds,
+        "n_clients": n_clients,
+        "current_server_round": current_server_round,
+    }
+
+
+def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_weight: bool) -> None:
+    """
+    Configure the FedAvgWithAdaptiveConstraint strategy and run the FL server until all rounds are complete.
+
+    Args:
+        config (Dict[str, Any]): Loaded configuration. The keys "batch_size", "local_epochs", "n_server_rounds" and
+            "n_clients" are read here.
+        server_address (str): Address the server is started on.
+        lam (float): Passed to the strategy as its initial loss weight.
+        adapt_loss_weight (bool): Passed to the strategy to indicate whether the loss weight is adapted.
+    """
+    # This function will be used to produce a config that is sent to each client to initialize their own environment
+    fit_config_fn = partial(
+        fit_config,
+        config["batch_size"],
+        config["local_epochs"],
+        config["n_server_rounds"],
+        config["n_clients"],
+    )
+
+    client_manager = SimpleClientManager()
+    # Initializing the model on the server side
+    model = ConvNet(in_channels=3, use_bn=False)
+    # Server performs simple FedAveraging as its server-side optimization strategy
+    strategy = FedAvgWithAdaptiveConstraint(
+        min_fit_clients=config["n_clients"],
+        min_evaluate_clients=config["n_clients"],
+        # Server waits for min_available_clients before starting FL rounds
+        min_available_clients=config["n_clients"],
+        on_fit_config_fn=fit_config_fn,
+        # We use the same fit config function, as nothing changes for eval
+        on_evaluate_config_fn=fit_config_fn,
+        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
+        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
+        initial_parameters=get_all_model_parameters(model),
+        initial_loss_weight=lam,
+        adapt_loss_weight=adapt_loss_weight,
+    )
+
+    server = PersonalMrMtlServer(client_manager=client_manager, strategy=strategy)
+
+    fl.server.start_server(
+        server=server,
+        server_address=server_address,
+        config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
+    )
+
+    log(INFO, "Training Complete")
+    log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{server.best_aggregated_loss}")
+
+    # Shutdown the server gracefully
+    server.shutdown()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="FL Server Main")
+    parser.add_argument(
+        "--config_path",
+        action="store",
+        type=str,
+        help="Path to configuration file.",
+        default="config.yaml",
+    )
+    parser.add_argument(
+        "--server_address",
+        action="store",
+        type=str,
+        help="Server Address to be used to communicate with the clients",
+        default="0.0.0.0:8080",
+    )
+    parser.add_argument(
+        "--seed",
+        action="store",
+        type=int,
+        help="Seed for the random number generators across python, torch, and numpy",
+        required=False,
+    )
+    parser.add_argument(
+        "--lam", action="store", type=float, help="MR-MTL loss weight for local model training", default=0.01
+    )
+    parser.add_argument(
+        "--use_adaptation",
+        action="store_true",
+        help="Whether or not the loss weight for model drift is adapted or remains fixed.",
+        default=False,
+    )
+    args = parser.parse_args()
+
+    config = load_config(args.config_path)
+    log(INFO, f"Server Address: {args.server_address}")
+    log(INFO, f"Lambda: {args.lam}")
+    if args.use_adaptation:
+        log(INFO, "Adapting the loss weight for model drift via global model loss")
+
+    # Set the random seed for reproducibility
+    set_all_random_seeds(args.seed)
+
+    main(config, args.server_address, args.lam, args.use_adaptation)
diff --git a/research/cifar10/evaluate_on_test.py b/research/cifar10/evaluate_on_test.py
new file mode 100644
index 000000000..a62183c99
--- /dev/null
+++ b/research/cifar10/evaluate_on_test.py
@@ -0,0 +1,748 @@
+import argparse
+import copy
+from logging import INFO
+from pathlib import Path
+from typing import Dict
+
+import torch
+from flwr.common.logger import log
+
+from fl4health.utils.dataset import TensorDataset
+from fl4health.utils.load_data import load_cifar10_test_data
+from fl4health.utils.metrics import Accuracy
+from fl4health.utils.sampler import DirichletLabelBasedSampler
+from
research.cifar10.preprocess import get_test_preprocessed_data
+from research.cifar10.utils import (
+    evaluate_cifar10_model,
+    get_all_run_folders,
+    get_metric_avg_std,
+    load_best_global_model,
+    load_eval_best_post_aggregation_local_model,
+    load_eval_best_pre_aggregation_local_model,
+    load_eval_last_post_aggregation_local_model,
+    load_eval_last_pre_aggregation_local_model,
+    load_last_global_model,
+    write_measurement_results,
+)
+
+NUM_CLIENTS = 5
+BATCH_SIZE = 32
+
+
+def main(
+    artifact_dir: str,
+    dataset_dir: str,
+    use_partitioned_data: bool,
+    eval_write_path: str,
+    eval_best_pre_aggregation_local_models: bool,
+    eval_last_pre_aggregation_local_models: bool,
+    eval_best_post_aggregation_local_models: bool,
+    eval_last_post_aggregation_local_models: bool,
+    eval_best_global_model: bool,
+    eval_last_global_model: bool,
+    eval_over_aggregated_test_data: bool,
+    heterogeneity_level: float,
+    is_apfl: bool,
+) -> None:
+    """
+    Evaluate the saved CIFAR-10 models of every run folder on each client's test data and write the results.
+
+    For each of the NUM_CLIENTS clients and each run folder, the requested local and/or global models are loaded
+    and evaluated on that client's test loader. Per client, the mean and standard deviation over runs are stored.
+    Per run, the client metrics are also accumulated into a single value, weighting each client by its share of
+    the total number of test examples, and the mean and standard deviation of those values over runs are stored.
+    When eval_over_aggregated_test_data is set, every model is additionally evaluated on the concatenation of all
+    client test sets.
+
+    Args:
+        artifact_dir (str): Directory handed to get_all_run_folders to discover the run folders.
+        dataset_dir (str): Directory holding the CIFAR-10 data.
+        use_partitioned_data (bool): If True, client test data comes from get_test_preprocessed_data. Otherwise
+            it is drawn from the CIFAR-10 test data with a DirichletLabelBasedSampler.
+        eval_write_path (str): Path handed to write_measurement_results.
+        eval_best_pre_aggregation_local_models (bool): Evaluate the best pre-aggregation local models.
+        eval_last_pre_aggregation_local_models (bool): Evaluate the last pre-aggregation local models.
+        eval_best_post_aggregation_local_models (bool): Evaluate the best post-aggregation local models.
+        eval_last_post_aggregation_local_models (bool): Evaluate the last post-aggregation local models.
+        eval_best_global_model (bool): Evaluate the best global (server) model.
+        eval_last_global_model (bool): Evaluate the last global (server) model.
+        eval_over_aggregated_test_data (bool): Also evaluate on the aggregated test data of all clients.
+        heterogeneity_level (float): Heterogeneity level used to select the preprocessed partition, or as the
+            beta of the DirichletLabelBasedSampler.
+        is_apfl (bool): Forwarded to evaluate_cifar10_model.
+    """
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    all_run_folder_dir = get_all_run_folders(artifact_dir)
+    test_results: Dict[str, float] = {}
+    metrics = [Accuracy("cifar10_accuracy")]
+
+    all_pre_best_local_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir}
+    all_pre_last_local_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir}
+    all_post_best_local_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir}
+    all_post_last_local_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir}
+
+    all_best_server_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir}
+    all_last_server_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir}
+
+    if eval_over_aggregated_test_data:
+        all_pre_best_local_agg_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir}
+        all_pre_last_local_agg_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir}
+        all_post_best_local_agg_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir}
+        all_post_last_local_agg_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir}
+
+        all_best_server_agg_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir}
+        all_last_server_agg_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir}
+
+    if eval_over_aggregated_test_data:
+        for client_number in range(NUM_CLIENTS):
+            if use_partitioned_data:
+                test_loader, _ = get_test_preprocessed_data(
+                    Path(dataset_dir), client_number, BATCH_SIZE, heterogeneity_level
+                )
+            else:
+                sampler = DirichletLabelBasedSampler(
+                    list(range(10)),
+                    sample_percentage=1.0 / NUM_CLIENTS,
+                    beta=heterogeneity_level,
+                    hash_key=client_number,
+                )
+                test_loader, _ = load_cifar10_test_data(Path(dataset_dir), BATCH_SIZE, sampler=sampler)
+            assert isinstance(test_loader.dataset, TensorDataset), "Expected TensorDataset."
+
+            if client_number == 0:
+                aggregated_dataset = copy.deepcopy(test_loader.dataset)
+            else:
+                assert aggregated_dataset.data is not None and test_loader.dataset.data is not None
+                aggregated_dataset.data = torch.cat((aggregated_dataset.data, test_loader.dataset.data))
+                assert aggregated_dataset.targets is not None and test_loader.dataset.targets is not None
+                aggregated_dataset.targets = torch.cat((aggregated_dataset.targets, test_loader.dataset.targets))
+
+        aggregated_test_loader = torch.utils.data.DataLoader(aggregated_dataset, batch_size=BATCH_SIZE, shuffle=False)
+        aggregated_num_examples = len(aggregated_dataset)
+    else:
+        # The per-client metrics below are weighted by each client's share of the total number of test examples,
+        # so the total is needed even when the aggregated test set itself is not built. Previously it was only
+        # defined in the branch above, which raised a NameError whenever eval_over_aggregated_test_data was False.
+        aggregated_num_examples = 0
+        for client_number in range(NUM_CLIENTS):
+            if use_partitioned_data:
+                _, num_examples = get_test_preprocessed_data(
+                    Path(dataset_dir), client_number, BATCH_SIZE, heterogeneity_level
+                )
+            else:
+                sampler = DirichletLabelBasedSampler(
+                    list(range(10)),
+                    sample_percentage=1.0 / NUM_CLIENTS,
+                    beta=heterogeneity_level,
+                    hash_key=client_number,
+                )
+                _, num_examples = load_cifar10_test_data(Path(dataset_dir), BATCH_SIZE, sampler=sampler)
+            aggregated_num_examples += num_examples["eval_set"]
+
+    for client_number in range(NUM_CLIENTS):
+        if use_partitioned_data:
+            test_loader, num_examples = get_test_preprocessed_data(
+                Path(dataset_dir), client_number, BATCH_SIZE, heterogeneity_level
+            )
+        else:
+            sampler = DirichletLabelBasedSampler(
+                list(range(10)),
+                sample_percentage=1.0 / NUM_CLIENTS,
+                beta=heterogeneity_level,
+                hash_key=client_number,
+            )
+            test_loader, num_examples = load_cifar10_test_data(Path(dataset_dir), BATCH_SIZE, sampler=sampler)
+
+        # Metrics of this client's data, one entry per run folder
+        pre_best_local_test_metrics = []
+        pre_last_local_test_metrics = []
+        post_best_local_test_metrics = []
+        post_last_local_test_metrics = []
+        best_server_test_metrics = []
+        last_server_test_metrics = []
+
+        if eval_over_aggregated_test_data:
+            pre_best_local_agg_test_metrics = []
+            pre_last_local_agg_test_metrics = []
+            post_best_local_agg_test_metrics = []
+            post_last_local_agg_test_metrics = []
+            best_server_agg_test_metrics = []
+            last_server_agg_test_metrics = []
+
+        for run_folder_dir in all_run_folder_dir:
+            if eval_best_pre_aggregation_local_models:
+                local_model = load_eval_best_pre_aggregation_local_model(run_folder_dir, client_number)
+                local_run_metric = evaluate_cifar10_model(local_model, test_loader, metrics, device, is_apfl)
+                log(
+                    INFO,
+                    f"Client Number {client_number}, Run folder: {run_folder_dir}: "
+                    f"Best Pre-aggregation Local Model Test Performance: {local_run_metric}",
+                )
+
+                pre_best_local_test_metrics.append(local_run_metric)
+                all_pre_best_local_test_metrics[run_folder_dir] += (
+                    local_run_metric * num_examples["eval_set"] / aggregated_num_examples
+                )
+                if eval_over_aggregated_test_data:
+
+                    agg_local_run_metric = evaluate_cifar10_model(
+                        local_model, aggregated_test_loader, metrics, device, is_apfl
+                    )
+                    log(
+                        INFO,
+                        f"Client Number {client_number}, Run folder: {run_folder_dir}: "
+                        f"Aggregated Best Pre-aggregation Local Model Test Performance: {agg_local_run_metric}",
+                    )
+                    pre_best_local_agg_test_metrics.append(agg_local_run_metric)
+                    all_pre_best_local_agg_test_metrics[run_folder_dir] += agg_local_run_metric / NUM_CLIENTS
+
+            if eval_last_pre_aggregation_local_models:
+                local_model = load_eval_last_pre_aggregation_local_model(run_folder_dir, client_number)
+                local_run_metric = evaluate_cifar10_model(local_model, test_loader, metrics, device, is_apfl)
+                log(
+                    INFO,
+                    f"Client Number {client_number}, Run folder: {run_folder_dir}: "
+                    f"Last Pre-aggregation Local Model Test Performance: {local_run_metric}",
+                )
+                pre_last_local_test_metrics.append(local_run_metric)
+                all_pre_last_local_test_metrics[run_folder_dir] += (
+                    local_run_metric * num_examples["eval_set"] / aggregated_num_examples
+                )
+
+                if eval_over_aggregated_test_data:
+
+                    agg_local_run_metric = evaluate_cifar10_model(
+                        local_model, aggregated_test_loader, metrics, device, is_apfl
+                    )
+                    log(
+                        INFO,
+                        f"Client Number {client_number}, Run folder: {run_folder_dir}: "
+                        f"Aggregated Last Pre-aggregation Local Model Test Performance: {agg_local_run_metric}",
+                    )
+                    pre_last_local_agg_test_metrics.append(agg_local_run_metric)
+                    all_pre_last_local_agg_test_metrics[run_folder_dir] += agg_local_run_metric / NUM_CLIENTS
+
+            if eval_best_post_aggregation_local_models:
+                local_model = load_eval_best_post_aggregation_local_model(run_folder_dir, client_number)
+                local_run_metric = evaluate_cifar10_model(local_model, test_loader, metrics, device, is_apfl)
+                log(
+                    INFO,
+                    f"Client Number {client_number}, Run folder: {run_folder_dir}: "
+                    f"Best Post-aggregation Local Model Test Performance: {local_run_metric}",
+                )
+                post_best_local_test_metrics.append(local_run_metric)
+                all_post_best_local_test_metrics[run_folder_dir] += (
+                    local_run_metric * num_examples["eval_set"] / aggregated_num_examples
+                )
+
+                if eval_over_aggregated_test_data:
+
+                    agg_local_run_metric = evaluate_cifar10_model(
+                        local_model, aggregated_test_loader, metrics, device, is_apfl
+                    )
+                    log(
+                        INFO,
+                        f"Client Number {client_number}, Run folder: {run_folder_dir}: "
+                        f"Aggregated Best Post-aggregation Local Model Test Performance: {agg_local_run_metric}",
+                    )
+                    post_best_local_agg_test_metrics.append(agg_local_run_metric)
+                    all_post_best_local_agg_test_metrics[run_folder_dir] += agg_local_run_metric / NUM_CLIENTS
+
+            if eval_last_post_aggregation_local_models:
+                local_model = load_eval_last_post_aggregation_local_model(run_folder_dir, client_number)
+                local_run_metric = evaluate_cifar10_model(local_model, test_loader, metrics, device, is_apfl)
+                log(
+                    INFO,
+                    f"Client Number {client_number}, Run folder: {run_folder_dir}: "
+                    f"Last Post-aggregation Local Model Test Performance: {local_run_metric}",
+                )
+                post_last_local_test_metrics.append(local_run_metric)
+                all_post_last_local_test_metrics[run_folder_dir] += (
+                    local_run_metric * num_examples["eval_set"] / aggregated_num_examples
+                )
+
+                if eval_over_aggregated_test_data:
+
+                    agg_local_run_metric = evaluate_cifar10_model(
+                        local_model, aggregated_test_loader, metrics, device, is_apfl
+                    )
+                    log(
+                        INFO,
+                        f"Client Number {client_number}, Run folder: {run_folder_dir}: "
+                        f"Aggregated Last Post-aggregation Local Model Test Performance: {agg_local_run_metric}",
+                    )
+                    post_last_local_agg_test_metrics.append(agg_local_run_metric)
+                    all_post_last_local_agg_test_metrics[run_folder_dir] += agg_local_run_metric / NUM_CLIENTS
+
+            if eval_best_global_model:
+                server_model = load_best_global_model(run_folder_dir)
+                server_run_metric = evaluate_cifar10_model(server_model, test_loader, metrics, device, is_apfl)
+                log(
+                    INFO,
+                    f"Client Number {client_number}, Run folder: {run_folder_dir}: "
+                    f"Server Best Model Test Performance: {server_run_metric}",
+                )
+                best_server_test_metrics.append(server_run_metric)
+                all_best_server_test_metrics[run_folder_dir] += (
+                    server_run_metric * num_examples["eval_set"] / aggregated_num_examples
+                )
+
+                if eval_over_aggregated_test_data:
+
+                    agg_server_run_metric = evaluate_cifar10_model(
+                        server_model, aggregated_test_loader, metrics, device, is_apfl
+                    )
+                    log(
+                        INFO,
+                        f"Client Number {client_number}, Run folder: {run_folder_dir}: "
+                        f"Aggregated Server Best Model Test Performance: {agg_server_run_metric}",
+                    )
+                    best_server_agg_test_metrics.append(agg_server_run_metric)
+                    all_best_server_agg_test_metrics[run_folder_dir] += agg_server_run_metric / NUM_CLIENTS
+
+            if eval_last_global_model:
+                server_model = load_last_global_model(run_folder_dir)
+                server_run_metric = evaluate_cifar10_model(server_model, test_loader, metrics, device, is_apfl)
+                log(
+                    INFO,
+                    f"Client Number {client_number}, Run folder: {run_folder_dir}: "
+                    f"Server Last Model Test Performance: {server_run_metric}",
+                )
+                last_server_test_metrics.append(server_run_metric)
+                all_last_server_test_metrics[run_folder_dir] += (
+                    server_run_metric * num_examples["eval_set"] / aggregated_num_examples
+                )
+
+                if eval_over_aggregated_test_data:
+
+                    agg_server_run_metric = evaluate_cifar10_model(
+                        server_model, aggregated_test_loader, metrics, device, is_apfl
+                    )
+                    log(
+                        INFO,
+                        f"Client Number {client_number}, Run folder: {run_folder_dir}: "
+                        f"Aggregated Server Last Model Test Performance: {agg_server_run_metric}",
+                    )
+                    last_server_agg_test_metrics.append(agg_server_run_metric)
+                    all_last_server_agg_test_metrics[run_folder_dir] += agg_server_run_metric / NUM_CLIENTS
+
+        # Write the results for each client
+        if eval_best_pre_aggregation_local_models:
+            avg_test_metric, std_test_metric = get_metric_avg_std(pre_best_local_test_metrics)
+            log(
+                INFO,
+                f"Client {client_number} Pre-aggregation Best Model Average Test "
+                f"Performance on own Data: {avg_test_metric}",
+            )
+            log(
+                INFO,
+                f"Client {client_number} Pre-aggregation Best Model St. Dev. Test "
+                f"Performance on own Data: {std_test_metric}",
+            )
+            test_results[f"client_{client_number}_pre_best_model_local_avg"] = avg_test_metric
+            test_results[f"client_{client_number}_pre_best_model_local_std"] = std_test_metric
+
+            if eval_over_aggregated_test_data:
+                avg_test_metric, std_test_metric = get_metric_avg_std(pre_best_local_agg_test_metrics)
+                log(
+                    INFO,
+                    f"Client {client_number} Pre-aggregation Best Model Average Test "
+                    f"Performance on Aggregated Data: {avg_test_metric}",
+                )
+                log(
+                    INFO,
+                    f"Client {client_number} Pre-aggregation Best Model St. Dev. Test "
+                    f"Performance on Aggregated Data: {std_test_metric}",
+                )
+                test_results[f"agg_client_{client_number}_pre_best_model_local_avg"] = avg_test_metric
+                test_results[f"agg_client_{client_number}_pre_best_model_local_std"] = std_test_metric
+
+        if eval_last_pre_aggregation_local_models:
+            avg_test_metric, std_test_metric = get_metric_avg_std(pre_last_local_test_metrics)
+            log(
+                INFO,
+                f"Client {client_number} Pre-aggregation Last Model Average Test "
+                f"Performance on own Data: {avg_test_metric}",
+            )
+            log(
+                INFO,
+                f"Client {client_number} Pre-aggregation Last Model St. Dev. Test "
+                f"Performance on own Data: {std_test_metric}",
+            )
+            test_results[f"client_{client_number}_pre_last_model_local_avg"] = avg_test_metric
+            test_results[f"client_{client_number}_pre_last_model_local_std"] = std_test_metric
+            if eval_over_aggregated_test_data:
+                avg_test_metric, std_test_metric = get_metric_avg_std(pre_last_local_agg_test_metrics)
+                log(
+                    INFO,
+                    f"Client {client_number} Pre-aggregation Last Model Average Test "
+                    f"Performance on Aggregated Data: {avg_test_metric}",
+                )
+                log(
+                    INFO,
+                    f"Client {client_number} Pre-aggregation Last Model St. Dev. Test "
+                    f"Performance on Aggregated Data: {std_test_metric}",
+                )
+                test_results[f"agg_client_{client_number}_pre_last_model_local_avg"] = avg_test_metric
+                test_results[f"agg_client_{client_number}_pre_last_model_local_std"] = std_test_metric
+
+        if eval_best_post_aggregation_local_models:
+            avg_test_metric, std_test_metric = get_metric_avg_std(post_best_local_test_metrics)
+            log(
+                INFO,
+                f"Client {client_number} Post-aggregation Best Model Average Test "
+                f"Performance on own Data: {avg_test_metric}",
+            )
+            log(
+                INFO,
+                f"Client {client_number} Post-aggregation Best Model St. Dev. Test "
+                f"Performance on own Data: {std_test_metric}",
+            )
+            test_results[f"client_{client_number}_post_best_model_local_avg"] = avg_test_metric
+            test_results[f"client_{client_number}_post_best_model_local_std"] = std_test_metric
+
+            if eval_over_aggregated_test_data:
+                avg_test_metric, std_test_metric = get_metric_avg_std(post_best_local_agg_test_metrics)
+                log(
+                    INFO,
+                    f"Client {client_number} Post-aggregation Best Model Average Test "
+                    f"Performance on Aggregated Data: {avg_test_metric}",
+                )
+                log(
+                    INFO,
+                    f"Client {client_number} Post-aggregation Best Model St. Dev. Test "
+                    f"Performance on Aggregated Data: {std_test_metric}",
+                )
+                test_results[f"agg_client_{client_number}_post_best_model_local_avg"] = avg_test_metric
+                test_results[f"agg_client_{client_number}_post_best_model_local_std"] = std_test_metric
+
+        if eval_last_post_aggregation_local_models:
+            avg_test_metric, std_test_metric = get_metric_avg_std(post_last_local_test_metrics)
+            log(
+                INFO,
+                f"Client {client_number} Post-aggregation Last Model Average Test "
+                f"Performance on own Data: {avg_test_metric}",
+            )
+            log(
+                INFO,
+                f"Client {client_number} Post-aggregation Last Model St. Dev. Test "
+                f"Performance on own Data: {std_test_metric}",
+            )
+            test_results[f"client_{client_number}_post_last_model_local_avg"] = avg_test_metric
+            test_results[f"client_{client_number}_post_last_model_local_std"] = std_test_metric
+
+            if eval_over_aggregated_test_data:
+                avg_test_metric, std_test_metric = get_metric_avg_std(post_last_local_agg_test_metrics)
+                log(
+                    INFO,
+                    f"Client {client_number} Post-aggregation Last Model Average Test "
+                    f"Performance on Aggregated Data: {avg_test_metric}",
+                )
+                log(
+                    INFO,
+                    f"Client {client_number} Post-aggregation Last Model St. Dev. Test "
+                    f"Performance on Aggregated Data: {std_test_metric}",
+                )
+                test_results[f"agg_client_{client_number}_post_last_model_local_avg"] = avg_test_metric
+                test_results[f"agg_client_{client_number}_post_last_model_local_std"] = std_test_metric
+
+        if eval_best_global_model:
+            avg_server_test_global_metric, std_server_test_global_metric = get_metric_avg_std(best_server_test_metrics)
+            log(
+                INFO,
+                f"Server Best model Average Test Performance on Client {client_number} "
+                f"Data: {avg_server_test_global_metric}",
+            )
+            log(
+                INFO,
+                f"Server Best model St. Dev. Test Performance on Client {client_number} "
+                f"Data: {std_server_test_global_metric}",
+            )
+            test_results[f"server_best_model_client_{client_number}_avg"] = avg_server_test_global_metric
+            test_results[f"server_best_model_client_{client_number}_std"] = std_server_test_global_metric
+
+        if eval_last_global_model:
+            avg_server_test_global_metric, std_server_test_global_metric = get_metric_avg_std(last_server_test_metrics)
+            log(
+                INFO,
+                f"Server Last model Average Test Performance on Client {client_number} "
+                f"Data: {avg_server_test_global_metric}",
+            )
+            log(
+                INFO,
+                f"Server Last model St. Dev. Test Performance on Client {client_number} "
+                f"Data: {std_server_test_global_metric}",
+            )
+            test_results[f"server_last_model_client_{client_number}_avg"] = avg_server_test_global_metric
+            test_results[f"server_last_model_client_{client_number}_std"] = std_server_test_global_metric
+
+        if eval_over_aggregated_test_data:
+            if eval_best_global_model:
+                avg_server_test_global_metric, std_server_test_global_metric = get_metric_avg_std(
+                    best_server_agg_test_metrics
+                )
+                log(
+                    INFO,
+                    "Server Best model Average Test Performance on Aggregated Client "
+                    f"Data: {avg_server_test_global_metric}",
+                )
+                log(
+                    INFO,
+                    "Server Best model St. Dev. Test Performance on Aggregated Client "
+                    f"Data: {std_server_test_global_metric}",
+                )
+                test_results["agg_server_best_model_client_avg"] = avg_server_test_global_metric
+                test_results["agg_server_best_model_client_std"] = std_server_test_global_metric
+
+            if eval_last_global_model:
+                avg_server_test_global_metric, std_server_test_global_metric = get_metric_avg_std(
+                    last_server_agg_test_metrics
+                )
+                log(
+                    INFO,
+                    "Server Last model Average Test Performance on Aggregated Client "
+                    f"Data: {avg_server_test_global_metric}",
+                )
+                log(
+                    INFO,
+                    "Server Last model St. Dev. Test Performance on Aggregated Client "
+                    f"Data: {std_server_test_global_metric}",
+                )
+                test_results["agg_server_last_model_client_avg"] = avg_server_test_global_metric
+                test_results["agg_server_last_model_client_std"] = std_server_test_global_metric
+
+    if eval_best_pre_aggregation_local_models:
+        all_avg_test_metric, all_std_test_metric = get_metric_avg_std(list(all_pre_best_local_test_metrics.values()))
+        test_results["avg_pre_best_local_model_avg_across_clients"] = all_avg_test_metric
+        test_results["std_pre_best_local_model_avg_across_clients"] = all_std_test_metric
+        log(INFO, f"Avg Pre-aggregation Best Local Model Test Performance Over all clients: {all_avg_test_metric}")
+        log(
+            INFO,
+            f"Std. Dev. Pre-aggregation Best Local Model Test Performance Over all clients: {all_std_test_metric}",
+        )
+        if eval_over_aggregated_test_data:
+            all_avg_test_metric, all_std_test_metric = get_metric_avg_std(
+                list(all_pre_best_local_agg_test_metrics.values())
+            )
+            test_results["agg_avg_pre_best_local_model_avg_across_clients"] = all_avg_test_metric
+            test_results["agg_std_pre_best_local_model_avg_across_clients"] = all_std_test_metric
+            log(
+                INFO,
+                "Avg Pre-aggregation Best Local Model Test "
+                f"Performance Over Aggregated clients: {all_avg_test_metric}",
+            )
+            log(
+                INFO,
+                "Std. Dev. Pre-aggregation Best Local Model Test "
+                f"Performance Over Aggregated clients: {all_std_test_metric}",
+            )
+
+    if eval_last_pre_aggregation_local_models:
+        all_avg_test_metric, all_std_test_metric = get_metric_avg_std(list(all_pre_last_local_test_metrics.values()))
+        test_results["avg_pre_last_local_model_avg_across_clients"] = all_avg_test_metric
+        test_results["std_pre_last_local_model_avg_across_clients"] = all_std_test_metric
+        log(INFO, f"Avg Pre-aggregation Last Local Model Test Performance Over all clients: {all_avg_test_metric}")
+        log(
+            INFO,
+            f"Std. Dev. Pre-aggregation Last Local Model Test Performance Over all clients: {all_std_test_metric}",
+        )
+        if eval_over_aggregated_test_data:
+            all_avg_test_metric, all_std_test_metric = get_metric_avg_std(
+                list(all_pre_last_local_agg_test_metrics.values())
+            )
+            test_results["agg_avg_pre_last_local_model_avg_across_clients"] = all_avg_test_metric
+            test_results["agg_std_pre_last_local_model_avg_across_clients"] = all_std_test_metric
+            log(
+                INFO,
+                "Avg Pre-aggregation Last Local Model Test "
+                f"Performance Over Aggregated clients: {all_avg_test_metric}",
+            )
+            log(
+                INFO,
+                "Std. Dev. Pre-aggregation Last Local Model Test "
+                f"Performance Over Aggregated clients: {all_std_test_metric}",
+            )
+
+    if eval_best_post_aggregation_local_models:
+        all_avg_test_metric, all_std_test_metric = get_metric_avg_std(list(all_post_best_local_test_metrics.values()))
+        test_results["avg_post_best_local_model_avg_across_clients"] = all_avg_test_metric
+        test_results["std_post_best_local_model_avg_across_clients"] = all_std_test_metric
+        log(INFO, f"Avg Post-aggregation Best Local Model Test Performance Over all clients: {all_avg_test_metric}")
+        log(
+            INFO,
+            f"Std. Dev. Post-aggregation Best Local Model Test Performance Over all clients: {all_std_test_metric}",
+        )
+        if eval_over_aggregated_test_data:
+            all_avg_test_metric, all_std_test_metric = get_metric_avg_std(
+                list(all_post_best_local_agg_test_metrics.values())
+            )
+            test_results["agg_avg_post_best_local_model_avg_across_clients"] = all_avg_test_metric
+            test_results["agg_std_post_best_local_model_avg_across_clients"] = all_std_test_metric
+            log(
+                INFO,
+                "Avg Post-aggregation Best Local Model Test "
+                f"Performance Over Aggregated clients: {all_avg_test_metric}",
+            )
+            log(
+                INFO,
+                "Std. Dev. Post-aggregation Best Local Model Test "
+                f"Performance Over Aggregated clients: {all_std_test_metric}",
+            )
+
+    if eval_last_post_aggregation_local_models:
+        all_avg_test_metric, all_std_test_metric = get_metric_avg_std(list(all_post_last_local_test_metrics.values()))
+        test_results["avg_post_last_local_model_avg_across_clients"] = all_avg_test_metric
+        test_results["std_post_last_local_model_avg_across_clients"] = all_std_test_metric
+        log(INFO, f"Avg Post-aggregation Last Local Model Test Performance Over all clients: {all_avg_test_metric}")
+        log(
+            INFO,
+            f"Std. Dev. Post-aggregation Last Local Model Test Performance Over all clients: {all_std_test_metric}",
+        )
+        if eval_over_aggregated_test_data:
+            all_avg_test_metric, all_std_test_metric = get_metric_avg_std(
+                list(all_post_last_local_agg_test_metrics.values())
+            )
+            test_results["agg_avg_post_last_local_model_avg_across_clients"] = all_avg_test_metric
+            test_results["agg_std_post_last_local_model_avg_across_clients"] = all_std_test_metric
+            log(
+                INFO,
+                "Avg Post-aggregation Last Local Model Test "
+                f"Performance Over Aggregated clients: {all_avg_test_metric}",
+            )
+            log(
+                INFO,
+                "Std. Dev. Post-aggregation Last Local Model Test "
+                f"Performance Over Aggregated clients: {all_std_test_metric}",
+            )
+
+    if eval_best_global_model:
+        all_server_avg_test_metric, all_server_std_test_metric = get_metric_avg_std(
+            list(all_best_server_test_metrics.values())
+        )
+        test_results["avg_best_server_model_avg_across_clients"] = all_server_avg_test_metric
+        test_results["std_best_server_model_avg_across_clients"] = all_server_std_test_metric
+        log(INFO, f"Avg. Best Server Model Test Performance Over all clients: {all_server_avg_test_metric}")
+        log(INFO, f"Std. Dev. Best Server Model Test Performance Over all clients: {all_server_std_test_metric}")
+
+        if eval_over_aggregated_test_data:
+            all_server_avg_test_metric, all_server_std_test_metric = get_metric_avg_std(
+                list(all_best_server_agg_test_metrics.values())
+            )
+            test_results["agg_avg_best_server_model_avg_across_clients"] = all_server_avg_test_metric
+            test_results["agg_std_best_server_model_avg_across_clients"] = all_server_std_test_metric
+            log(
+                INFO,
+                "Avg. Best Server Model Test Performance Over Aggregated "
+                f"clients: {all_server_avg_test_metric}",
+            )
+            log(
+                INFO,
+                "Std. Dev. Best Server Model Test Performance Over Aggregated "
+                f"clients: {all_server_std_test_metric}",
+            )
+
+    if eval_last_global_model:
+        all_server_avg_test_metric, all_server_std_test_metric = get_metric_avg_std(
+            list(all_last_server_test_metrics.values())
+        )
+        test_results["avg_last_server_model_avg_across_clients"] = all_server_avg_test_metric
+        test_results["std_last_server_model_avg_across_clients"] = all_server_std_test_metric
+        log(INFO, f"Avg. Last Server Model Test Performance Over all clients: {all_server_avg_test_metric}")
+        log(INFO, f"Std. Dev. Last Server Model Test Performance Over all clients: {all_server_std_test_metric}")
+
+        if eval_over_aggregated_test_data:
+            all_server_avg_test_metric, all_server_std_test_metric = get_metric_avg_std(
+                list(all_last_server_agg_test_metrics.values())
+            )
+            test_results["agg_avg_last_server_model_avg_across_clients"] = all_server_avg_test_metric
+            test_results["agg_std_last_server_model_avg_across_clients"] = all_server_std_test_metric
+            log(
+                INFO,
+                "Avg. Last Server Model Test Performance Over Aggregated "
+                f"clients: {all_server_avg_test_metric}",
+            )
+            log(
+                INFO,
+                "Std. Dev. Last Server Model Test Performance Over Aggregated "
+                f"clients: {all_server_std_test_metric}",
+            )
+
+    write_measurement_results(eval_write_path, test_results)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Evaluate Trained Models on Test Data")
+    parser.add_argument(
+        "--artifact_dir",
+        action="store",
+        type=str,
+        help="Path to saved model artifacts to be evaluated",
+        required=True,
+    )
+    parser.add_argument(
+        "--dataset_dir",
+        action="store",
+        type=str,
+        help="Path to the preprocessed Cifar10 Dataset (ex.
path/to/cifar10)", + required=True, + ) + parser.add_argument( + "--use_partitioned_data", + action="store_true", + help="Use preprocessed partitioned data for training, validation and testing", + default=True, + ) + parser.add_argument( + "--eval_write_path", + action="store", + type=str, + help="Path to write the evaluation results file", + required=True, + ) + parser.add_argument( + "--eval_best_global_model", + action="store_true", + help="boolean to indicate whether to search for and evaluate best server model in addition to client models", + ) + parser.add_argument( + "--eval_last_global_model", + action="store_true", + help="boolean to indicate whether to search for and evaluate last server model in addition to client models", + ) + parser.add_argument( + "--eval_best_pre_aggregation_local_models", + action="store_true", + help="""boolean to indicate whether to search for and evaluate best pre-aggregation local models in addition + to the server model""", + ) + parser.add_argument( + "--eval_best_post_aggregation_local_models", + action="store_true", + help="""boolean to indicate whether to search for and evaluate best post-aggregation local models in addition + to the server model""", + ) + parser.add_argument( + "--eval_last_pre_aggregation_local_models", + action="store_true", + help="""boolean to indicate whether to search for and evaluate last pre-aggregation local models in addition + to the server model""", + ) + parser.add_argument( + "--eval_last_post_aggregation_local_models", + action="store_true", + help="""boolean to indicate whether to search for and evaluate last post-aggregation local models in addition + to the server model""", + ) + parser.add_argument( + "--eval_over_aggregated_test_data", + action="store_true", + help="""boolean to indicate whether to evaluate all the models on the over-aggregated test data as well as + client specific data""", + ) + + parser.add_argument( + "--is_apfl", + action="store_true", + help="boolean to indicate 
whether we're evaluating an APFL model or not, as those model have special args", + ) + parser.add_argument( + "--beta", + action="store", + type=float, + help="Heterogeneity level for the dataset", + required=False, + default=0.1, + ) + + args = parser.parse_args() + log(INFO, f"Artifact Directory: {args.artifact_dir}") + log(INFO, f"Dataset Directory: {args.dataset_dir}") + log(INFO, f"Eval Write Path: {args.eval_write_path}") + + log(INFO, f"Run Best Global Model: {args.eval_best_global_model}") + log(INFO, f"Run Last Global Model: {args.eval_last_global_model}") + log(INFO, f"Run Best Pre-aggregation Local Model: {args.eval_best_pre_aggregation_local_models}") + log(INFO, f"Run Last Pre-aggregation Local Model: {args.eval_last_pre_aggregation_local_models}") + log(INFO, f"Run Best Post-aggregation Local Model: {args.eval_best_post_aggregation_local_models}") + log(INFO, f"Run Last Post-aggregation Local Model: {args.eval_last_post_aggregation_local_models}") + log(INFO, f"Run Eval Over Aggregated Test Data: {args.eval_over_aggregated_test_data}") + + log(INFO, f"Heterogeneity level for the dataset: {args.beta}") + log(INFO, f"Is APFL Run: {args.is_apfl}") + + assert ( + args.eval_best_global_model + or args.eval_last_global_model + or args.eval_best_pre_aggregation_local_models + or args.eval_last_pre_aggregation_local_models + or args.eval_best_post_aggregation_local_models + or args.eval_last_post_aggregation_local_models + ) + main( + args.artifact_dir, + args.dataset_dir, + args.use_partitioned_data, + args.eval_write_path, + args.eval_best_pre_aggregation_local_models, + args.eval_last_pre_aggregation_local_models, + args.eval_best_post_aggregation_local_models, + args.eval_last_post_aggregation_local_models, + args.eval_best_global_model, + args.eval_last_global_model, + args.eval_over_aggregated_test_data, + args.beta, + args.is_apfl, + ) diff --git a/research/cifar10/find_best_hp.py b/research/cifar10/find_best_hp.py new file mode 100644 index 
000000000..b44a63ad1 --- /dev/null +++ b/research/cifar10/find_best_hp.py @@ -0,0 +1,60 @@ +import argparse +import os +from logging import INFO +from typing import List, Optional + +import numpy as np +from flwr.common.logger import log + + +def get_hp_folders(hp_sweep_dir: str) -> List[str]: + paths_in_hp_sweep_dir = [os.path.join(hp_sweep_dir, contents) for contents in os.listdir(hp_sweep_dir)] + return [hp_folder for hp_folder in paths_in_hp_sweep_dir if os.path.isdir(hp_folder)] + + +def get_run_folders(hp_dir: str) -> List[str]: + run_folder_names = [folder_name for folder_name in os.listdir(hp_dir) if "Run" in folder_name] + return [os.path.join(hp_dir, run_folder_name) for run_folder_name in run_folder_names] + + +def get_weighted_loss_from_server_log(run_folder_path: str) -> float: + server_log_path = os.path.join(run_folder_path, "server.out") + with open(server_log_path, "r") as handle: + files_lines = handle.readlines() + line_to_convert = files_lines[-1].strip() + return float(line_to_convert) + + +def main(hp_sweep_dir: str) -> None: + hp_folders = get_hp_folders(hp_sweep_dir) + best_avg_loss: Optional[float] = None + best_folder = "" + for hp_folder in hp_folders: + run_folders = get_run_folders(hp_folder) + hp_losses = [] + for run_folder in run_folders: + run_loss = get_weighted_loss_from_server_log(run_folder) + hp_losses.append(run_loss) + current_avg_loss = float(np.mean(hp_losses)) + if best_avg_loss is None or current_avg_loss <= best_avg_loss: + log(INFO, f"Current Loss: {current_avg_loss} is lower than Best Loss: {best_avg_loss}") + log(INFO, f"Best Folder: {hp_folder}, Previous Best: {best_folder}") + best_avg_loss = current_avg_loss + best_folder = hp_folder + log(INFO, f"Best Loss: {best_avg_loss}") + log(INFO, f"Best Folder: {best_folder}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate Holdout Global") + parser.add_argument( + "--hp_sweep_dir", + action="store", + type=str, + help="Path to the 
artifacts of the hyper-parameter sweep script", + required=True, + ) + args = parser.parse_args() + + log(INFO, f"Hyperparameter Sweep Directory: {args.hp_sweep_dir}") + main(args.hp_sweep_dir) diff --git a/research/cifar10/model.py b/research/cifar10/model.py index 287d2a570..ef17fda5e 100644 --- a/research/cifar10/model.py +++ b/research/cifar10/model.py @@ -2,9 +2,12 @@ import torch.nn as nn from torch.nn import BatchNorm2d, Conv2d, Flatten, Linear, MaxPool2d, Module, ReLU +from fl4health.model_bases.fenda_base import FendaModel +from fl4health.model_bases.parallel_split_models import ParallelFeatureJoinMode, ParallelSplitHeadModule +from fl4health.model_bases.sequential_split_models import SequentiallySplitModel -class ConvNet(Module): +class ConvNet(Module): def __init__( self, in_channels: int, @@ -45,3 +48,104 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.fc2(x) return x + + +class ConvNetFeatureExtractor(Module): + def __init__( + self, + in_channels: int, + use_bn: bool = True, + ) -> None: + super().__init__() + + self.conv1 = Conv2d(in_channels, 32, 5, padding=2) + self.conv2 = Conv2d(32, 64, 5, padding=2) + self.use_bn = use_bn + if use_bn: + self.bn1 = BatchNorm2d(32) + self.bn2 = BatchNorm2d(64) + + self.relu = ReLU(inplace=True) + self.maxpool = MaxPool2d(2) + self.flatten = Flatten() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + x = self.bn1(self.conv1(x)) if self.use_bn else self.conv1(x) + x = self.maxpool(self.relu(x)) + x = self.bn2(self.conv2(x)) if self.use_bn else self.conv2(x) + x = self.maxpool(self.relu(x)) + x = self.flatten(x) + + return x + + +class ConvNetClassifier(ParallelSplitHeadModule): + def __init__( + self, + join_mode: ParallelFeatureJoinMode, + h: int = 32, + w: int = 32, + hidden: int = 2048, + class_num: int = 10, + dropout: float = 0.0, + ) -> None: + super().__init__(join_mode) + + # Times 2 because we'll be concatenating the inputs of two feature extractors + self.fc1 = Linear((h // 2 // 
2) * (w // 2 // 2) * 64 * 2, hidden) + self.fc2 = Linear(hidden, class_num) + + self.relu = ReLU(inplace=True) + self.dropout_layer = nn.Dropout(p=dropout) + + def parallel_output_join(self, local_tensor: torch.Tensor, global_tensor: torch.Tensor) -> torch.Tensor: + # Assuming tensors are "batch first", we concatenate along the channel dimension + return torch.concat([local_tensor, global_tensor], dim=1) + + def head_forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.dropout_layer(x) + x = self.relu(self.fc1(x)) + x = self.dropout_layer(x) + x = self.fc2(x) + + return x + + +class ConvNetFendaModel(FendaModel): + def __init__( + self, + in_channels: int, + h: int = 32, + w: int = 32, + hidden: int = 2048, + class_num: int = 10, + use_bn: bool = True, + dropout: float = 0.0, + ) -> None: + # FedIXI out_channels_first_layer = 8 is the Baseline model default. So we use it here. The monte carlo dropout + # is also set to 0 by default for FedIXI + local_module = ConvNetFeatureExtractor(in_channels, use_bn) + global_module = ConvNetFeatureExtractor(in_channels, use_bn) + model_head = ConvNetClassifier( + ParallelFeatureJoinMode.CONCATENATE, h=h, w=w, hidden=hidden, class_num=class_num, dropout=dropout + ) + super().__init__(local_module=local_module, global_module=global_module, model_head=model_head) + + +class ConvNetFendaDittoGlobalModel(SequentiallySplitModel): + def __init__( + self, + in_channels: int, + h: int = 32, + w: int = 32, + hidden: int = 2048, + class_num: int = 10, + use_bn: bool = True, + dropout: float = 0.0, + ) -> None: + base_module = ConvNetFeatureExtractor(in_channels, use_bn) + head_module = ConvNetClassifier( + ParallelFeatureJoinMode.CONCATENATE, h=h, w=w, hidden=hidden, class_num=class_num, dropout=dropout + ) + super().__init__(base_module, head_module, flatten_features=False) diff --git a/research/cifar10/personal_server.py b/research/cifar10/personal_server.py new file mode 100644 index 000000000..1d4758d3b --- /dev/null +++ 
b/research/cifar10/personal_server.py @@ -0,0 +1,63 @@ +from logging import INFO +from typing import Dict, Optional, Tuple + +from flwr.common.logger import log +from flwr.common.typing import Scalar +from flwr.server.client_manager import ClientManager +from flwr.server.server import EvaluateResultsAndFailures +from flwr.server.strategy import Strategy + +from fl4health.server.base_server import FlServer + + +class PersonalServer(FlServer): + """ + The PersonalServer class is used for FL approaches that only have a sense of a PERSONAL model that is checkpointed + and valid only on the client side of the FL training framework. FL approaches like APFL and FENDA fall under this + category. Each client will have its own model that is specific to its own training. Personal models may have + shared components but the full model is specific to each client. This is distinct from the + FlServerWithCheckpointing class which has a sense of a GLOBAL model checkpointed on the server-side that is + shared by all clients. + """ + + def __init__( + self, + client_manager: ClientManager, + strategy: Optional[Strategy] = None, + ) -> None: + # Personal approaches don't train a "server" model. Rather, each client trains a client specific model with + # some globally shared weights.
So we don't checkpoint a global model + super().__init__(client_manager, strategy, checkpointer=None) + self.best_aggregated_loss: Optional[float] = None + + def evaluate_round( + self, + server_round: int, + timeout: Optional[float], + ) -> Optional[Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]]: + # loss_aggregated is the aggregated validation per step loss + # aggregated over each client (weighted by num examples) + eval_round_results = super().evaluate_round(server_round, timeout) + assert eval_round_results is not None + loss_aggregated, metrics_aggregated, (results, failures) = eval_round_results + assert loss_aggregated is not None + + if self.best_aggregated_loss: + if self.best_aggregated_loss >= loss_aggregated: + log( + INFO, + f"Best Aggregated Loss: {self.best_aggregated_loss} " + f"is larger than current aggregated loss: {loss_aggregated}", + ) + self.best_aggregated_loss = loss_aggregated + else: + log( + INFO, + f"Best Aggregated Loss: {self.best_aggregated_loss} " + f"is smaller than current aggregated loss: {loss_aggregated}", + ) + else: + log(INFO, f"Saving Best Aggregated Loss: {loss_aggregated} as it is currently None") + self.best_aggregated_loss = loss_aggregated + + return loss_aggregated, metrics_aggregated, (results, failures) diff --git a/research/cifar10/adaptive_pfl/data_preprocess_scripts/preprocess.slrm b/research/cifar10/pfl_preprocess_scripts/preprocess.slrm similarity index 100% rename from research/cifar10/adaptive_pfl/data_preprocess_scripts/preprocess.slrm rename to research/cifar10/pfl_preprocess_scripts/preprocess.slrm diff --git a/research/cifar10/adaptive_pfl/data_preprocess_scripts/preprocess_all.sh b/research/cifar10/pfl_preprocess_scripts/preprocess_all.sh similarity index 92% rename from research/cifar10/adaptive_pfl/data_preprocess_scripts/preprocess_all.sh rename to research/cifar10/pfl_preprocess_scripts/preprocess_all.sh index bc28cd567..526e131e5 100644 --- 
a/research/cifar10/adaptive_pfl/data_preprocess_scripts/preprocess_all.sh +++ b/research/cifar10/pfl_preprocess_scripts/preprocess_all.sh @@ -26,7 +26,7 @@ do CLIENT_ERROR_LOGS="cifar_preprocess_log_${SEEDS}_${BETAS}_${NUM_PARTITIONS}.err" SBATCH_COMMAND="--job-name=cifar_preprocess_${BETA} --output=${CLIENT_OUT_LOGS} --error=${CLIENT_ERROR_LOGS} \ - research/cifar10/adaptive_pfl/data_preprocess_scripts/preprocess.slrm \ + research/cifar10/pfl_preprocess_scripts/preprocess.slrm \ ${VENV_PATH} ${ORIGINAL_DATA_DIR} ${DESTINATION_DIRS[index]} ${SEEDS[index]} ${BETAS[index]} \ ${NUM_PARTITIONS[index]} ${DESTINATION_DIRS[index]}" \ diff --git a/research/cifar10/utils.py b/research/cifar10/utils.py new file mode 100644 index 000000000..b2a823fe8 --- /dev/null +++ b/research/cifar10/utils.py @@ -0,0 +1,94 @@ +import os +from typing import Dict, List, Sequence, Tuple + +import numpy as np +import torch +from torch import nn +from torch.utils.data import DataLoader + +from fl4health.utils.metrics import Metric, MetricManager + + +def get_all_run_folders(artifact_dir: str) -> List[str]: + run_folder_names = [folder_name for folder_name in os.listdir(artifact_dir) if "Run" in folder_name] + return [os.path.join(artifact_dir, run_folder_name) for run_folder_name in run_folder_names] + + +def load_best_global_model(run_folder_dir: str) -> nn.Module: + model_checkpoint_path = os.path.join(run_folder_dir, "server_best_model.pkl") + model = torch.load(model_checkpoint_path) + return model + + +def load_last_global_model(run_folder_dir: str) -> nn.Module: + model_checkpoint_path = os.path.join(run_folder_dir, "server_last_model.pkl") + model = torch.load(model_checkpoint_path) + return model + + +def get_metric_avg_std(metrics: List[float]) -> Tuple[float, float]: + mean = float(np.mean(metrics)) + std = float(np.std(metrics, ddof=1)) + return mean, std + + +def write_measurement_results(eval_write_path: str, results: Dict[str, float]) -> None: + with open(eval_write_path, "w") as 
f: + for key, metric_value in results.items(): + f.write(f"{key}: {metric_value}\n") + + +def evaluate_cifar10_model( + model: nn.Module, dataset: DataLoader, metrics: Sequence[Metric], device: torch.device, is_apfl: bool +) -> float: + meter = evaluate_model_on_dataset(model, dataset, metrics, device, is_apfl) + + computed_metrics = meter.compute() + assert "test_meter - prediction - cifar10_accuracy" in computed_metrics + accuracy = computed_metrics["test_meter - prediction - cifar10_accuracy"] + assert isinstance(accuracy, float) + return accuracy + + +def load_eval_best_pre_aggregation_local_model(run_folder_dir: str, client_number: int) -> nn.Module: + model_checkpoint_path = os.path.join(run_folder_dir, f"pre_aggregation_client_{client_number}_best_model.pkl") + model = torch.load(model_checkpoint_path) + return model + + +def load_eval_last_pre_aggregation_local_model(run_folder_dir: str, client_number: int) -> nn.Module: + model_checkpoint_path = os.path.join(run_folder_dir, f"pre_aggregation_client_{client_number}_last_model.pkl") + model = torch.load(model_checkpoint_path) + return model + + +def load_eval_best_post_aggregation_local_model(run_folder_dir: str, client_number: int) -> nn.Module: + model_checkpoint_path = os.path.join(run_folder_dir, f"post_aggregation_client_{client_number}_best_model.pkl") + model = torch.load(model_checkpoint_path) + return model + + +def load_eval_last_post_aggregation_local_model(run_folder_dir: str, client_number: int) -> nn.Module: + model_checkpoint_path = os.path.join(run_folder_dir, f"post_aggregation_client_{client_number}_last_model.pkl") + model = torch.load(model_checkpoint_path) + return model + + +def evaluate_model_on_dataset( + model: nn.Module, dataset: DataLoader, metrics: Sequence[Metric], device: torch.device, is_apfl: bool +) -> MetricManager: + model.to(device).eval() + meter = MetricManager(metrics, "test_meter") + + with torch.no_grad(): + for input, target in dataset: + input, target = 
input.to(device), target.to(device) + if is_apfl: + preds = model(input)["personal"] + else: + preds = model(input) + if isinstance(preds, tuple): + preds = preds[0] + preds = preds if isinstance(preds, dict) else {"prediction": preds} + meter.update(preds, target) + return meter From d583bceaa896958a9b47ce26fd8963bce7e06a68 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Tue, 8 Oct 2024 17:10:00 -0400 Subject: [PATCH 04/19] Changes to facilitate better packaging of loss components into information communicated with the server. For now this takes the form of packing loss values into the metrics dictionary if requested via the config from the server. --- examples/feddg_ga_example/README.md | 1 + examples/feddg_ga_example/config.yaml | 2 + examples/feddg_ga_example/server.py | 2 +- fl4health/clients/basic_client.py | 79 ++-- fl4health/clients/ditto_client.py | 4 +- fl4health/clients/evaluate_client.py | 2 +- fl4health/clients/fedrep_client.py | 2 +- fl4health/clients/flash_client.py | 8 +- fl4health/clients/mr_mtl_client.py | 4 +- fl4health/clients/nnunet_client.py | 2 +- fl4health/server/base_server.py | 12 +- .../fedavg_with_adaptive_constraint.py | 6 +- .../{feddg_ga_strategy.py => feddg_ga.py} | 62 ++- .../feddg_ga_with_adaptive_constraint.py | 268 ++++++++++++ fl4health/utils/metrics.py | 8 +- tests/clients/test_basic_client.py | 24 +- tests/server/test_base_server.py | 20 +- .../smoke_tests/feddg_ga_client_metrics.json | 6 +- tests/strategies/test_feddg_ga_strategy.py | 101 ++++- .../test_feddg_ga_with_adapt_constraint.py | 411 ++++++++++++++++++ 20 files changed, 922 insertions(+), 102 deletions(-) rename fl4health/strategies/{feddg_ga_strategy.py => feddg_ga.py} (87%) create mode 100644 fl4health/strategies/feddg_ga_with_adaptive_constraint.py create mode 100644 tests/strategies/test_feddg_ga_with_adapt_constraint.py diff --git a/examples/feddg_ga_example/README.md b/examples/feddg_ga_example/README.md index 
adf3956f7..7f677ad7e 100644 --- a/examples/feddg_ga_example/README.md +++ b/examples/feddg_ga_example/README.md @@ -29,6 +29,7 @@ from the FL4Health directory. The following arguments must be present in the spe * `batch_size`: size of the batches each client will train on * `n_server_rounds`: The number of rounds to run FL * `evaluate_after_fit`: Should be set to `True`. Performs an evaluation at the end of each client's fit round. +* `pack_losses_with_val_metrics`: Should be set to `True`. Includes validation losses with metrics calculations ## Starting Clients diff --git a/examples/feddg_ga_example/config.yaml b/examples/feddg_ga_example/config.yaml index 992143f44..71fb1e2cb 100644 --- a/examples/feddg_ga_example/config.yaml +++ b/examples/feddg_ga_example/config.yaml @@ -7,3 +7,5 @@ local_steps: 5 # The number of local steps (one per batch) to complete for clien batch_size: 128 # The batch size for client training # Evaluates model immediately after local training on the validation set (in addition to the training set) evaluate_after_fit: True +# Packs the measured validation losses with the metrics (required for fed-dgga) +pack_losses_with_val_metrics: True diff --git a/examples/feddg_ga_example/server.py b/examples/feddg_ga_example/server.py index 4cca07800..d78226247 100644 --- a/examples/feddg_ga_example/server.py +++ b/examples/feddg_ga_example/server.py @@ -10,7 +10,7 @@ from fl4health.client_managers.fixed_sampling_client_manager import FixedSamplingClientManager from fl4health.model_bases.apfl_base import ApflModule from fl4health.server.base_server import FlServer -from fl4health.strategies.feddg_ga_strategy import FedDgGaStrategy +from fl4health.strategies.feddg_ga import FedDgGaStrategy from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters diff --git 
a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index 70c8a06ee..e69170d00 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -24,7 +24,7 @@ from fl4health.reporting.metrics import MetricsReporter from fl4health.utils.config import narrow_dict_type, narrow_dict_type_and_set_attribute from fl4health.utils.losses import EvaluationLosses, LossMeter, LossMeterType, TrainingLosses -from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, Metric, MetricManager +from fl4health.utils.metrics import TEST_NUM_EXAMPLES_KEY, Metric, MetricManager, MetricPrefix from fl4health.utils.random import generate_hash from fl4health.utils.typing import LogLevel, TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType @@ -213,7 +213,7 @@ def shutdown(self) -> None: self.metrics_reporter.add_to_metrics({"shutdown": datetime.datetime.now()}) - def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, None], int, bool]: + def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, None], int, bool, bool]: """ Method to ensure the required keys are present in config and extracts values to be returned. 
@@ -247,8 +247,13 @@ def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, N except ValueError: evaluate_after_fit = False + try: + pack_losses_with_val_metrics = narrow_dict_type(config, "pack_losses_with_val_metrics", bool) + except ValueError: + pack_losses_with_val_metrics = False + # Either local epochs or local steps is none based on what key is passed in the config - return local_epochs, local_steps, current_server_round, evaluate_after_fit + return local_epochs, local_steps, current_server_round, evaluate_after_fit, pack_losses_with_val_metrics def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]: """ @@ -267,7 +272,9 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict Raises: ValueError: If local_steps or local_epochs is not specified in config. """ - local_epochs, local_steps, current_server_round, evaluate_after_fit = self.process_config(config) + local_epochs, local_steps, current_server_round, evaluate_after_fit, pack_losses_with_val_metrics = ( + self.process_config(config) + ) if not self.initialized: self.setup_client(config) @@ -300,7 +307,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict # Check if we should run an evaluation with validation data after fit # (for example, this is used by FedDGGA) if self._should_evaluate_after_fit(evaluate_after_fit): - validation_loss, validation_metrics = self.evaluate_after_fit() + validation_loss, validation_metrics = self.validate(pack_losses_with_val_metrics) metrics.update(validation_metrics) # We perform a pre-aggregation checkpoint if applicable self._maybe_checkpoint(validation_loss, validation_metrics, CheckpointMode.PRE_AGGREGATION) @@ -326,21 +333,6 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict metrics, ) - def evaluate_after_fit(self) -> Tuple[float, Dict[str, Scalar]]: - """ - Run self.validate right after fit to collect metrics 
on the local model against validation data. - - Returns: (Dict[str, Scalar]) a dictionary with the metrics. - - """ - loss, metric_values = self.validate() - # The computed loss value is packed into the metrics dictionary, perhaps for use on the server-side - metrics_after_fit = { - **metric_values, # type: ignore - "val - loss": loss, - } - return loss, metrics_after_fit - def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Dict[str, Scalar]]: """ Evaluates the model on the validation set, and test set (if defined). @@ -357,13 +349,19 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Di self.setup_client(config) current_server_round = narrow_dict_type(config, "current_server_round", int) + + try: + pack_losses_with_val_metrics = narrow_dict_type(config, "pack_losses_with_val_metrics", bool) + except ValueError: + pack_losses_with_val_metrics = False + self.metrics_reporter.add_to_metrics_at_round( current_server_round, data={"evaluate_start": datetime.datetime.now()}, ) self.set_parameters(parameters, config, fitting_round=False) - loss, metrics = self.validate() + loss, metrics = self.validate(pack_losses_with_val_metrics) # Checkpoint based on the loss and metrics produced during validation AFTER server-side aggregation # NOTE: This assumes that the loss returned in the checkpointing loss @@ -787,6 +785,7 @@ def _validate_or_test( loss_meter: LossMeter, metric_manager: MetricManager, logging_mode: LoggingMode = LoggingMode.VALIDATION, + include_losses_in_metrics: bool = False, ) -> Tuple[float, Dict[str, Scalar]]: """ Evaluate the model on the given validation or test dataset. @@ -795,8 +794,10 @@ def _validate_or_test( loader (DataLoader): The data loader for the dataset (validation or test). loss_meter (LossMeter): The meter to track the losses. metric_manager (MetricManager): The manager to track the metrics. 
- logging_mode (LoggingMode): The LoggingMode for whether this evaluation is for validation or test. - Default is for validation. + logging_mode (LoggingMode, optional): The LoggingMode for whether this evaluation is for validation or + test. Defaults to LoggingMode.VALIDATION. + include_losses_in_metrics (bool, optional): Whether or not to pack the additional losses into the metrics + dictionary. Defaults to False. Returns: Tuple[float, Dict[str, Scalar]]: The loss and a dictionary of metrics from evaluation. @@ -818,9 +819,23 @@ def _validate_or_test( metrics = metric_manager.compute() self._log_results(loss_dict, metrics, logging_mode=logging_mode) + if include_losses_in_metrics: + self._fold_loss_dict_into_metrics(metrics, loss_dict, logging_mode) + return loss_dict["checkpoint"], metrics - def validate(self) -> Tuple[float, Dict[str, Scalar]]: + def _fold_loss_dict_into_metrics( + self, metrics: Dict[str, Scalar], loss_dict: Dict[str, float], logging_mode: LoggingMode + ) -> None: + # Prefixing the loss value keys with the mode from which they are generated + if logging_mode is LoggingMode.VALIDATION: + metrics.update({f"{MetricPrefix.VAL_PREFIX.value} {key}": loss_val for key, loss_val in loss_dict.items()}) + else: + metrics.update( + {f"{MetricPrefix.TEST_PREFIX.value} {key}": loss_val for key, loss_val in loss_dict.items()} + ) + + def validate(self, include_losses_in_metrics: bool = False) -> Tuple[float, Dict[str, Scalar]]: """ Validate the current model on the entire validation and potentially an entire test dataset if it has been defined. @@ -829,15 +844,23 @@ def validate(self) -> Tuple[float, Dict[str, Scalar]]: Tuple[float, Dict[str, Scalar]]: The validation loss and a dictionary of metrics from validation (and test if present). 
""" - val_loss, val_metrics = self._validate_or_test(self.val_loader, self.val_loss_meter, self.val_metric_manager) + val_loss, val_metrics = self._validate_or_test( + self.val_loader, + self.val_loss_meter, + self.val_metric_manager, + include_losses_in_metrics=include_losses_in_metrics, + ) if self.test_loader: - test_loss, test_metrics = self._validate_or_test( - self.test_loader, self.test_loss_meter, self.test_metric_manager, LoggingMode.TEST + _, test_metrics = self._validate_or_test( + self.test_loader, + self.test_loss_meter, + self.test_metric_manager, + LoggingMode.TEST, + include_losses_in_metrics=include_losses_in_metrics, ) # There will be no clashes due to the naming convention associated with the metric managers if self.num_test_samples is not None: val_metrics[TEST_NUM_EXAMPLES_KEY] = self.num_test_samples - val_metrics[TEST_LOSS_KEY] = test_loss val_metrics.update(test_metrics) return val_loss, val_metrics diff --git a/fl4health/clients/ditto_client.py b/fl4health/clients/ditto_client.py index d39629e29..bda5bfd25 100644 --- a/fl4health/clients/ditto_client.py +++ b/fl4health/clients/ditto_client.py @@ -353,7 +353,7 @@ def compute_training_loss( return TrainingLosses(backward=loss + penalty_loss, additional_losses=additional_losses) - def validate(self) -> Tuple[float, Dict[str, Scalar]]: + def validate(self, include_losses_in_metrics: bool = False) -> Tuple[float, Dict[str, Scalar]]: """ Validate the current model on the entire validation dataset. 
@@ -362,7 +362,7 @@ def validate(self) -> Tuple[float, Dict[str, Scalar]]: """ # Set the global model to evaluate mode self.global_model.eval() - return super().validate() + return super().validate(include_losses_in_metrics=include_losses_in_metrics) def compute_evaluation_loss( self, diff --git a/fl4health/clients/evaluate_client.py b/fl4health/clients/evaluate_client.py index 000084518..795c85194 100644 --- a/fl4health/clients/evaluate_client.py +++ b/fl4health/clients/evaluate_client.py @@ -168,7 +168,7 @@ def validate_on_model( self._handle_logging(losses, metrics, is_global) return losses, metrics - def validate(self) -> Tuple[float, Dict[str, Scalar]]: + def validate(self, include_loss_in_metrics: bool = False) -> Tuple[float, Dict[str, Scalar]]: local_loss: Optional[EvaluationLosses] = None local_metrics: Optional[Dict[str, Scalar]] = None diff --git a/fl4health/clients/fedrep_client.py b/fl4health/clients/fedrep_client.py index 71a3b219e..815c3d7cc 100644 --- a/fl4health/clients/fedrep_client.py +++ b/fl4health/clients/fedrep_client.py @@ -238,7 +238,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict # Check if we should run an evaluation with validation data after fit # (for example, this is used by FedDGGA) if self._should_evaluate_after_fit(evaluate_after_fit): - validation_loss, validation_metrics = self.evaluate_after_fit() + validation_loss, validation_metrics = self.validate() metrics.update(validation_metrics) # We perform a pre-aggregation checkpoint if applicable self._maybe_checkpoint(validation_loss, validation_metrics, CheckpointMode.PRE_AGGREGATION) diff --git a/fl4health/clients/flash_client.py b/fl4health/clients/flash_client.py index cf6ccde19..7064c55bb 100644 --- a/fl4health/clients/flash_client.py +++ b/fl4health/clients/flash_client.py @@ -47,14 +47,16 @@ def __init__( # gamma: Threshold for early stopping based on the change in validation loss. 
self.gamma: Optional[float] = None - def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, None], int, bool]: - local_epochs, local_steps, current_server_round, evaluate_after_fit = super().process_config(config) + def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, None], int, bool, bool]: + local_epochs, local_steps, current_server_round, evaluate_after_fit, pack_losses_with_val_metrics = ( + super().process_config(config) + ) if local_steps is not None: raise ValueError( "Training by steps is not applicable for FLASH clients.\ Please define 'local_epochs' in your config instead" ) - return local_epochs, local_steps, current_server_round, evaluate_after_fit + return local_epochs, local_steps, current_server_round, evaluate_after_fit, pack_losses_with_val_metrics def train_by_epochs( self, epochs: int, current_round: Optional[int] = None diff --git a/fl4health/clients/mr_mtl_client.py b/fl4health/clients/mr_mtl_client.py index 7879bcba0..b35ceeaca 100644 --- a/fl4health/clients/mr_mtl_client.py +++ b/fl4health/clients/mr_mtl_client.py @@ -149,7 +149,7 @@ def compute_training_loss( # Use the rest of the training loss computation from the AdaptiveDriftConstraintClient parent return super().compute_training_loss(preds, features, target) - def validate(self) -> Tuple[float, Dict[str, Scalar]]: + def validate(self, include_losses_in_metrics: bool = False) -> Tuple[float, Dict[str, Scalar]]: """ Validate the current model on the entire validation dataset. 
@@ -158,4 +158,4 @@ def validate(self) -> Tuple[float, Dict[str, Scalar]]: """ # ensure that the initial global model is in eval mode assert not self.initial_global_model.training - return super().validate() + return super().validate(include_losses_in_metrics=include_losses_in_metrics) diff --git a/fl4health/clients/nnunet_client.py b/fl4health/clients/nnunet_client.py index 87a8b47d1..a58b4371c 100644 --- a/fl4health/clients/nnunet_client.py +++ b/fl4health/clients/nnunet_client.py @@ -269,7 +269,7 @@ def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler: ) # Determine total number of steps throughout all FL rounds - local_epochs, local_steps, _, _ = self.process_config(config) + local_epochs, local_steps, _, _, _ = self.process_config(config) if local_steps is not None: steps_per_round = local_steps elif local_epochs is not None: diff --git a/fl4health/server/base_server.py b/fl4health/server/base_server.py index 09f81029f..fdfc0a81f 100644 --- a/fl4health/server/base_server.py +++ b/fl4health/server/base_server.py @@ -22,7 +22,7 @@ from fl4health.server.polling import poll_clients from fl4health.strategies.strategy_with_poll import StrategyWithPolling from fl4health.utils.config import narrow_dict_type_and_set_attribute -from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, TestMetricPrefix +from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, MetricPrefix from fl4health.utils.parameter_extraction import get_all_model_parameters from fl4health.utils.random import generate_hash @@ -197,11 +197,9 @@ def _unpack_metrics( for client_proxy, eval_res in results: val_metrics = { - k: v for k, v in eval_res.metrics.items() if not k.startswith(TestMetricPrefix.TEST_PREFIX.value) - } - test_metrics = { - k: v for k, v in eval_res.metrics.items() if k.startswith(TestMetricPrefix.TEST_PREFIX.value) + k: v for k, v in eval_res.metrics.items() if not k.startswith(MetricPrefix.TEST_PREFIX.value) } + test_metrics 
= {k: v for k, v in eval_res.metrics.items() if k.startswith(MetricPrefix.TEST_PREFIX.value)} if len(test_metrics) > 0: assert TEST_LOSS_KEY in test_metrics and TEST_NUM_EXAMPLES_KEY in test_metrics, ( @@ -245,9 +243,7 @@ def _handle_result_aggregation( for key, value in test_metrics_aggregated.items(): val_metrics_aggregated[key] = value if test_loss_aggregated is not None: - val_metrics_aggregated[f"{TestMetricPrefix.TEST_PREFIX.value} loss - aggregated"] = ( - test_loss_aggregated - ) + val_metrics_aggregated[f"{MetricPrefix.TEST_PREFIX.value} loss - aggregated"] = test_loss_aggregated return val_loss_aggregated, val_metrics_aggregated diff --git a/fl4health/strategies/fedavg_with_adaptive_constraint.py b/fl4health/strategies/fedavg_with_adaptive_constraint.py index 9b1e59030..39f46ac98 100644 --- a/fl4health/strategies/fedavg_with_adaptive_constraint.py +++ b/fl4health/strategies/fedavg_with_adaptive_constraint.py @@ -55,9 +55,6 @@ def __init__( Implementation based on https://arxiv.org/abs/1602.05629. Args: - initial_parameters (Parameters): Initial global model parameters. - init_loss_weight (float): Initial loss weight (mu in FedProx). If adaptivity is false, then this is the - constant weight used for all clients. fraction_fit (float, optional): Fraction of clients used during training. Defaults to 1.0. fraction_evaluate (float, optional): Fraction of clients used during validation. Defaults to 1.0. min_fit_clients (int, optional): _description_. Defaults to 2. @@ -74,10 +71,13 @@ def __init__( Function used to configure server-side central validation by providing a Config dictionary. Defaults to None. accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True. + initial_parameters (Parameters): Initial global model parameters. fit_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. Defaults to None. 
evaluate_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. Defaults to None. + init_loss_weight (float): Initial loss weight (mu in FedProx). If adaptivity is false, then this is the + constant weight used for all clients. adapt_loss_weight (bool, optional): Determines whether the value of mu is adaptively modified by the server based on aggregated train loss. Defaults to False. loss_weight_delta (float, optional): This is the amount by which the server changes the value of mu diff --git a/fl4health/strategies/feddg_ga_strategy.py b/fl4health/strategies/feddg_ga.py similarity index 87% rename from fl4health/strategies/feddg_ga_strategy.py rename to fl4health/strategies/feddg_ga.py index e511314f6..e791bbb24 100644 --- a/fl4health/strategies/feddg_ga_strategy.py +++ b/fl4health/strategies/feddg_ga.py @@ -3,7 +3,14 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np -from flwr.common import MetricsAggregationFn, NDArrays, Parameters, ndarrays_to_parameters, parameters_to_ndarrays +from flwr.common import ( + EvaluateIns, + MetricsAggregationFn, + NDArrays, + Parameters, + ndarrays_to_parameters, + parameters_to_ndarrays, +) from flwr.common.logger import log from flwr.common.typing import EvaluateRes, FitIns, FitRes, Scalar from flwr.server.client_manager import ClientManager @@ -23,7 +30,7 @@ class FairnessMetricType(Enum): """Defines the basic types for fairness metrics, their default names and their default signals""" ACCURACY = "val - prediction - accuracy" - LOSS = "val - loss" + LOSS = "val - checkpoint" CUSTOM = "custom" @classmethod @@ -109,7 +116,10 @@ def __init__( weight_step_size: float = 0.2, ): """Strategy for the FedDG-GA algorithm (Federated Domain Generalization with - Generalization Adjustment, Zhang et al. 2023). + Generalization Adjustment, Zhang et al. 2023). 
This strategy assumes (and checks) that the configuration + sent by the server to the clients has the key "evaluate_after_fit" and it is set to True. It also ensures + that the key "pack_losses_with_val_metrics" is present and its value is set to True. These are to facilitate + the exchange of evaluation information needed for the strategy to work correctly. Args: fraction_fit : float, optional @@ -133,9 +143,9 @@ def __init__( ] Optional function used for validation. Defaults to None. on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure training. Defaults to None. + Function used to configure training. Must be specified for this strategy. Defaults to None. on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure validation. Defaults to None. + Function used to configure validation. Must be specified for this strategy. Defaults to None. accept_failures : bool, optional Whether or not accept rounds containing failures. Defaults to True. 
initial_parameters : Parameters, optional @@ -220,18 +230,41 @@ def configure_fit( self.initial_adjustment_weight = 1.0 / len(client_fit_ins) - # Setting self.num_rounds + # Setting self.num_rounds once and doing some sanity checks + assert self.on_fit_config_fn is not None, "on_fit_config_fn must be specified" + config = self.on_fit_config_fn(server_round) + assert "evaluate_after_fit" in config, "evaluate_after_fit must be present in config" + assert config["evaluate_after_fit"] is True, "evaluate_after_fit must be set to True" + + assert "pack_losses_with_val_metrics" in config, "pack_losses_with_val_metrics must be present in config" + assert config["pack_losses_with_val_metrics"] is True, "pack_losses_with_val_metrics must be set to True" + + assert "n_server_rounds" in config, "n_server_rounds must be specified" + assert isinstance(config["n_server_rounds"], int), "n_server_rounds is not an integer" + if self.num_rounds is None: - assert self.on_fit_config_fn is not None, "on_fit_config_fn must be specified" - config = self.on_fit_config_fn(server_round) - assert "evaluate_after_fit" in config, "evaluate_after_fit must be present in config and set to True" - assert config["evaluate_after_fit"] is True, "evaluate_after_fit must be set to True" - assert "n_server_rounds" in config, "n_server_rounds must be specified" - assert isinstance(config["n_server_rounds"], int), "n_server_rounds is not an integer" self.num_rounds = config["n_server_rounds"] + else: + assert config["n_server_rounds"] == self.num_rounds, "n_server_rounds has changed from the original value" return client_fit_ins + def configure_evaluate( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, EvaluateIns]]: + assert isinstance( + client_manager, FixedSamplingClientManager + ), f"Client manager is not of type FixedSamplingClientManager: {type(client_manager)}" + + client_evaluate_ins = super().configure_evaluate(server_round, 
parameters, client_manager) + + assert self.on_evaluate_config_fn is not None, "on_fit_config_fn must be specified" + config = self.on_evaluate_config_fn(server_round) + assert "pack_losses_with_val_metrics" in config, "pack_losses_with_val_metrics must be present in config" + assert config["pack_losses_with_val_metrics"] is True, "pack_losses_with_val_metrics must be set to True" + + return client_evaluate_ins + def aggregate_fit( self, server_round: int, @@ -300,10 +333,9 @@ def aggregate_evaluate( self.evaluation_metrics = {} for client_proxy, eval_res in results: cid = client_proxy.cid + # make sure that the metrics has the desired loss key + assert FairnessMetricType.LOSS.value in eval_res.metrics self.evaluation_metrics[cid] = eval_res.metrics - # adding the loss to the metrics - val_loss_key = FairnessMetricType.LOSS.value - self.evaluation_metrics[cid][val_loss_key] = eval_res.loss # Updating the weights at the end of the training round cids = [client_proxy.cid for client_proxy, _ in results] diff --git a/fl4health/strategies/feddg_ga_with_adaptive_constraint.py b/fl4health/strategies/feddg_ga_with_adaptive_constraint.py new file mode 100644 index 000000000..e99281e50 --- /dev/null +++ b/fl4health/strategies/feddg_ga_with_adaptive_constraint.py @@ -0,0 +1,268 @@ +from logging import INFO, WARNING +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +from flwr.common import MetricsAggregationFn, NDArrays, Parameters, ndarrays_to_parameters, parameters_to_ndarrays +from flwr.common.logger import log +from flwr.common.typing import FitRes, Scalar +from flwr.server.client_proxy import ClientProxy + +from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint +from fl4health.strategies.aggregate_utils import aggregate_losses +from fl4health.strategies.feddg_ga import FairnessMetric, FairnessMetricType, FedDgGaStrategy + + +class FedDgGaAdaptiveConstraint(FedDgGaStrategy): + def __init__( + self, + 
*, + fraction_fit: float = 1.0, + fraction_evaluate: float = 1.0, + min_fit_clients: int = 2, + min_evaluate_clients: int = 2, + min_available_clients: int = 2, + evaluate_fn: Optional[ + Callable[ + [int, NDArrays, Dict[str, Scalar]], + Optional[Tuple[float, Dict[str, Scalar]]], + ] + ] = None, + on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + accept_failures: bool = True, + initial_parameters: Parameters, + fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + initial_loss_weight: float = 1.0, + adapt_loss_weight: bool = False, + loss_weight_delta: float = 0.1, + loss_weight_patience: int = 5, + weighted_train_losses: bool = False, + fairness_metric: Optional[FairnessMetric] = None, + weight_step_size: float = 0.2, + ): + """ + Strategy for the FedDG-GA algorithm (Federated Domain Generalization with Generalization Adjustment, + Zhang et al. 2023) combined with the Adaptive Strategy for Auxiliary constraints like FedProx. See + documentation on FedAvgWithAdaptiveConstraint for more information. + + NOTE: Initial parameters are NOT optional. They must be passed for this strategy. + + Args: + fraction_fit : float, optional + Fraction of clients used during training. In case `min_fit_clients` + is larger than `fraction_fit * available_clients`, `min_fit_clients` + will still be sampled. Defaults to 1.0. + fraction_evaluate : float, optional + Fraction of clients used during validation. In case `min_evaluate_clients` + is larger than `fraction_evaluate * available_clients`, `min_evaluate_clients` + will still be sampled. Defaults to 1.0. + min_fit_clients : int, optional + Minimum number of clients used during training. Defaults to 2. + min_evaluate_clients : int, optional + Minimum number of clients used during validation. Defaults to 2. 
+ min_available_clients : int, optional + Minimum number of total clients in the system. Defaults to 2. + evaluate_fn : + Optional[ + Callable[[int, NDArrays, Dict[str, Scalar]], + Optional[Tuple[float, Dict[str, Scalar]]]] + ] + Optional function used for validation. Defaults to None. + on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure training. Defaults to None. + on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure validation. Defaults to None. + accept_failures : bool, optional + Whether or not accept rounds containing failures. Defaults to True. + initial_parameters : Parameters + Initial global model parameters. Required for this strategy. + fit_metrics_aggregation_fn : Optional[MetricsAggregationFn] + Metrics aggregation function, optional. + evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn] + Metrics aggregation function, optional. + adapt_loss_weight (bool, optional): Determines whether the value of mu is adaptively modified by + the server based on aggregated train loss. Defaults to False. + loss_weight_delta (float, optional): This is the amount by which the server changes the value of mu + based on the modification criteria. Only applicable if adaptivity is on. Defaults to 0.1. + loss_weight_patience (int, optional): This is the number of rounds a server must see decreasing + aggregated train loss before reducing the value of mu. Only applicable if adaptivity is on. + Defaults to 5. + weighted_train_losses (bool, optional): Determines whether the training losses from the clients should be + aggregated using a weighted or unweighted average. These aggregated losses are used to adjust the + proximal weight in the adaptive setting. Defaults to False. + fairness_metric : FairnessMetric, optional. + The metric to evaluate the local model of each client against the global model in order to + determine their adjustment weight for aggregation. 
Can be set to any default metric in + FairnessMetricType or set to use a custom metric. Optional, default is + FairnessMetric(FairnessMetricType.LOSS). + weight_step_size : float + The step size to determine the magnitude of change for the adjustment weight. It has to be + 0 < weight_step_size < 1. Optional, default is 0.2. + """ + if fraction_fit != 1.0 or fraction_evaluate != 1.0: + log( + WARNING, + "fraction_fit or fraction_evaluate are not 1.0. The behaviour of FedDG-GA is unknown in those cases.", + ) + + self.loss_weight = initial_loss_weight + self.adapt_loss_weight = adapt_loss_weight + + if self.adapt_loss_weight: + self.loss_weight_delta = loss_weight_delta + self.loss_weight_patience = loss_weight_patience + self.loss_weight_patience_counter: int = 0 + + self.previous_loss = float("inf") + + self.server_model_weights = parameters_to_ndarrays(initial_parameters) + initial_parameters.tensors.extend(ndarrays_to_parameters([np.array(initial_loss_weight)]).tensors) + + super().__init__( + fraction_fit=fraction_fit, + fraction_evaluate=fraction_evaluate, + min_fit_clients=min_fit_clients, + min_evaluate_clients=min_evaluate_clients, + min_available_clients=min_available_clients, + evaluate_fn=evaluate_fn, + on_fit_config_fn=on_fit_config_fn, + on_evaluate_config_fn=on_evaluate_config_fn, + accept_failures=accept_failures, + initial_parameters=initial_parameters, + fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, + ) + + if fairness_metric is None: + self.fairness_metric = FairnessMetric(FairnessMetricType.LOSS) + else: + self.fairness_metric = fairness_metric + + self.weight_step_size = weight_step_size + assert 0 < self.weight_step_size < 1, f"weight_step_size has to be between 0 and 1 ({self.weight_step_size})" + + self.train_metrics: Dict[str, Dict[str, Scalar]] = {} + self.evaluation_metrics: Dict[str, Dict[str, Scalar]] = {} + self.num_rounds: Optional[int] = None + 
self.initial_adjustment_weight: Optional[float] = None + self.adjustment_weights: Dict[str, float] = {} + self.parameter_packer = ParameterPackerAdaptiveConstraint() + self.weighted_train_losses = weighted_train_losses + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """ + Aggregate fit results by weighing them against the adjustment weights and then summing them. + + Collects the fit metrics that will be used to change the adjustment weights for the next round. + + If applicable, determine whether the constraint weight should be updated based on the aggregated loss + seen on the clients. + + Args: + server_round: (int) the current server round. + results: (List[Tuple[ClientProxy, FitRes]]) The clients' fit results. + failures: (List[Union[Tuple[ClientProxy, FitRes], BaseException]]) the clients' fit failures. + + Returns: + (Tuple[Optional[Parameters], Dict[str, Scalar]]) A tuple containing the aggregated parameters + and the aggregated fit metrics. For adaptive constraints, the server also packs a constraint weight + to be sent to the clients. This is sent even if adaptive constraint weights are turned off and + the value simply remains constant. + """ + if not results: + return None, {} + # Do not aggregate if there are failures and failures are not accepted + if not self.accept_failures and failures: + return None, {} + + # Convert results with packed params of model weights and training loss. 
The results list is modified in-place + # to only contain model parameters for use in the Fed-DGGA calculations and aggregation + train_losses_and_counts = self._unpack_weights_and_losses(results) + + # Aggregate train loss + train_losses_aggregated = aggregate_losses(train_losses_and_counts, self.weighted_train_losses) + self._maybe_update_constraint_weight_param(train_losses_aggregated) + + # Aggregate custom metrics if aggregation fn was provided + metrics_aggregated = {} + if self.fit_metrics_aggregation_fn: + fit_metrics = [(res.num_examples, res.metrics) for _, res in results] + metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics) + elif server_round == 1: # Only log this warning once + log(WARNING, "No fit_metrics_aggregation_fn provided") + + self.train_metrics = {} + for client_proxy, fit_res in results: + self.train_metrics[client_proxy.cid] = fit_res.metrics + + weights_aggregated = self.weight_and_aggregate_results(results) + + parameters = self.parameter_packer.pack_parameters(weights_aggregated, self.loss_weight) + return ndarrays_to_parameters(parameters), metrics_aggregated + + def _unpack_weights_and_losses(self, results: List[Tuple[ClientProxy, FitRes]]) -> List[Tuple[int, float]]: + """ + This function takes results returned from a fit round from each of the participating clients and unpacks the + information into the appropriate objects. The parameters contained in the FitRes object are unpacked to + separate the model weights from the training losses. The model weights are reinserted into the parameters + of the FitRes objects and the losses (along with sample counts) are placed in a list and returned + + NOTE: The results that are passed to this function are MODIFIED IN-PLACE + + Args: + results (List[Tuple[ClientProxy, FitRes]]): The results produced in a fitting round by each of the clients; + the FitRes objects contain both model weights and training losses which need to be processed. 
+ + Returns: + List[Tuple[int, float]]: A list of the training losses produced by client training + """ + train_losses_and_counts: List[Tuple[int, float]] = [] + for _, fit_res in results: + sample_count = fit_res.num_examples + updated_weights, train_loss = self.parameter_packer.unpack_parameters( + parameters_to_ndarrays(fit_res.parameters) + ) + # Modify the parameters in-place to just be the model weights. + fit_res.parameters = ndarrays_to_parameters(updated_weights) + train_losses_and_counts.append((sample_count, train_loss)) + + return train_losses_and_counts + + def _maybe_update_constraint_weight_param(self, loss: float) -> None: + """ + Update constraint weight parameter if adaptive_loss_weight is set to True. Regardless of whether adaptivity + is turned on at this time, the previous loss seen by the server is updated. + + Args: + loss (float): This is the loss to which we compare the previous loss seen by the server. For Adaptive + Constraint clients this should be the aggregated training loss seen by each client participating in + training. + NOTE: For adaptive constraint losses, including FedProx, this loss is exchanged (along with the weights) + by each client and is the VANILLA loss that does not include the additional penalty losses. 
+ """ + + if self.adapt_loss_weight: + if loss <= self.previous_loss: + self.loss_weight_patience_counter += 1 + if self.loss_weight_patience_counter == self.loss_weight_patience: + self.loss_weight -= self.loss_weight_delta + self.loss_weight = max(0.0, self.loss_weight) + self.loss_weight_patience_counter = 0 + log(INFO, f"Aggregate training loss has dropped {self.loss_weight_patience} rounds in a row") + log(INFO, f"Constraint weight is decreased to {self.loss_weight}") + else: + self.loss_weight += self.loss_weight_delta + self.loss_weight_patience_counter = 0 + log( + INFO, + f"Aggregate training loss increased this round: Current loss {loss}, " + f"Previous loss: {self.previous_loss}", + ) + log(INFO, f"Constraint weight is increased by {self.loss_weight_delta} to {self.loss_weight}") + self.previous_loss = loss diff --git a/fl4health/utils/metrics.py b/fl4health/utils/metrics.py index 642dd8c1c..77cab3567 100644 --- a/fl4health/utils/metrics.py +++ b/fl4health/utils/metrics.py @@ -12,13 +12,13 @@ from fl4health.utils.typing import TorchPredType, TorchTargetType, TorchTransformFunction -class TestMetricPrefix(Enum): - __test__ = False +class MetricPrefix(Enum): TEST_PREFIX = "test -" + VAL_PREFIX = "val -" -TEST_NUM_EXAMPLES_KEY = f"{TestMetricPrefix.TEST_PREFIX.value} num_examples" -TEST_LOSS_KEY = f"{TestMetricPrefix.TEST_PREFIX.value} loss" +TEST_NUM_EXAMPLES_KEY = f"{MetricPrefix.TEST_PREFIX.value} num_examples" +TEST_LOSS_KEY = f"{MetricPrefix.TEST_PREFIX.value} loss" class Metric(ABC): diff --git a/tests/clients/test_basic_client.py b/tests/clients/test_basic_client.py index ffedb6805..94442a410 100644 --- a/tests/clients/test_basic_client.py +++ b/tests/clients/test_basic_client.py @@ -65,12 +65,21 @@ def test_metrics_reporter_evaluate() -> None: test_metrics_final = { "test_metric": 1234, "testing_metric": 1234, - "test - loss": 123.123, + "val - checkpoint": 123.123, + "test - checkpoint": 123.123, "test - num_examples": 0, } - fl_client = 
MockBasicClient(loss=test_loss, metrics=test_metrics, test_set_metrics=test_metrics_testing) - fl_client.evaluate([], {"current_server_round": test_current_server_round, "local_epochs": 0}) + fl_client = MockBasicClient( + loss=test_loss, + loss_dict={"checkpoint": test_loss}, + metrics=test_metrics, + test_set_metrics=test_metrics_testing, + ) + fl_client.evaluate( + [], + {"current_server_round": test_current_server_round, "local_epochs": 0, "pack_losses_with_val_metrics": True}, + ) assert fl_client.metrics_reporter.metrics == { "type": "client", @@ -156,12 +165,11 @@ def __init__( self._validate_or_test.side_effect = self.mock_validate_or_test def mock_validate_or_test( # type: ignore - self, - loader, - loss_meter, - metric_manager, - logging_mode=LoggingMode.VALIDATION, + self, loader, loss_meter, metric_manager, logging_mode=LoggingMode.VALIDATION, include_losses_in_metrics=False ): + if include_losses_in_metrics: + assert self.mock_loss_dict is not None and self.mock_metrics is not None + self._fold_loss_dict_into_metrics(self.mock_metrics, self.mock_loss_dict, logging_mode) if logging_mode == LoggingMode.VALIDATION: return self.mock_loss, self.mock_metrics else: diff --git a/tests/server/test_base_server.py b/tests/server/test_base_server.py index 8dd50c229..8f834c83c 100644 --- a/tests/server/test_base_server.py +++ b/tests/server/test_base_server.py @@ -19,7 +19,7 @@ from fl4health.server.base_server import FlServer, FlServerWithCheckpointing from fl4health.strategies.basic_fedavg import BasicFedAvg from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn -from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, TestMetricPrefix +from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, MetricPrefix from tests.test_utils.custom_client_proxy import CustomClientProxy from tests.test_utils.models_for_test import LinearTransform @@ -152,7 +152,7 @@ def test_unpack_metrics() -> None: "val - prediction - 
accuracy": 0.9, TEST_LOSS_KEY: 0.8, TEST_NUM_EXAMPLES_KEY: 5, - f"{TestMetricPrefix.TEST_PREFIX.value} accuracy": 0.85, + f"{MetricPrefix.TEST_PREFIX.value} accuracy": 0.85, }, ) @@ -167,7 +167,7 @@ def test_unpack_metrics() -> None: # Check the test results assert len(test_results) == 1 - assert test_results[0][1].metrics[f"{TestMetricPrefix.TEST_PREFIX.value} accuracy"] == 0.85 + assert test_results[0][1].metrics[f"{MetricPrefix.TEST_PREFIX.value} accuracy"] == 0.85 assert test_results[0][1].loss == 0.8 @@ -185,7 +185,7 @@ def test_handle_result_aggregation() -> None: "val - prediction - accuracy": 0.9, TEST_LOSS_KEY: 0.8, TEST_NUM_EXAMPLES_KEY: 5, - f"{TestMetricPrefix.TEST_PREFIX.value} accuracy": 0.85, + f"{MetricPrefix.TEST_PREFIX.value} accuracy": 0.85, }, ) client_proxy2 = CustomClientProxy("2") @@ -197,7 +197,7 @@ def test_handle_result_aggregation() -> None: "val - prediction - accuracy": 0.8, TEST_LOSS_KEY: 1.6, TEST_NUM_EXAMPLES_KEY: 10, - f"{TestMetricPrefix.TEST_PREFIX.value} accuracy": 0.75, + f"{MetricPrefix.TEST_PREFIX.value} accuracy": 0.75, }, ) @@ -205,17 +205,17 @@ def test_handle_result_aggregation() -> None: failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [] server_round = 1 - val_loss_aggregated, val_metrics_aggregated = fl_server._handle_result_aggregation(server_round, results, failures) + _, val_metrics_aggregated = fl_server._handle_result_aggregation(server_round, results, failures) # Check the aggregated validation metrics assert "val - prediction - accuracy" in val_metrics_aggregated assert val_metrics_aggregated["val - prediction - accuracy"] == pytest.approx(0.8333, rel=1e-3) # Check the aggregated test metrics - assert f"{TestMetricPrefix.TEST_PREFIX.value} accuracy" in val_metrics_aggregated - assert val_metrics_aggregated[f"{TestMetricPrefix.TEST_PREFIX.value} accuracy"] == pytest.approx(0.7833, rel=1e-3) - assert f"{TestMetricPrefix.TEST_PREFIX.value} loss - aggregated" in val_metrics_aggregated - assert 
val_metrics_aggregated[f"{TestMetricPrefix.TEST_PREFIX.value} loss - aggregated"] == pytest.approx( + assert f"{MetricPrefix.TEST_PREFIX.value} accuracy" in val_metrics_aggregated + assert val_metrics_aggregated[f"{MetricPrefix.TEST_PREFIX.value} accuracy"] == pytest.approx(0.7833, rel=1e-3) + assert f"{MetricPrefix.TEST_PREFIX.value} loss - aggregated" in val_metrics_aggregated + assert val_metrics_aggregated[f"{MetricPrefix.TEST_PREFIX.value} loss - aggregated"] == pytest.approx( 1.333, rel=1e-3 ) diff --git a/tests/smoke_tests/feddg_ga_client_metrics.json b/tests/smoke_tests/feddg_ga_client_metrics.json index 0be323091..fa12841fd 100644 --- a/tests/smoke_tests/feddg_ga_client_metrics.json +++ b/tests/smoke_tests/feddg_ga_client_metrics.json @@ -9,7 +9,7 @@ "val - personal - accuracy": 0.6254, "val - global - accuracy": 0.6757, "val - local - accuracy": 0.5723, - "val - loss": 1.4845 + "val - checkpoint": 1.4845 }, "loss_dict": { "global": 1.2897, @@ -31,7 +31,7 @@ "val - personal - accuracy": {"target_value": 0.6778, "custom_tolerance": 0.005}, "val - global - accuracy": 0.76, "val - local - accuracy": {"target_value": 0.5066, "custom_tolerance": 0.005}, - "val - loss": {"target_value": 0.9618, "custom_tolerance": 0.005} + "val - checkpoint": {"target_value": 0.9618, "custom_tolerance": 0.005} }, "loss_dict": { "global": 0.7053, @@ -53,7 +53,7 @@ "val - personal - accuracy": {"target_value": 0.739, "custom_tolerance": 0.005}, "val - global - accuracy": 0.78, "val - local - accuracy": {"target_value": 0.5602, "custom_tolerance": 0.005}, - "val - loss": {"target_value": 0.8043, "custom_tolerance": 0.005} + "val - checkpoint": {"target_value": 0.8043, "custom_tolerance": 0.005} }, "loss_dict": { "global": 0.4757, diff --git a/tests/strategies/test_feddg_ga_strategy.py b/tests/strategies/test_feddg_ga_strategy.py index 3416caf60..8e3a7d598 100644 --- a/tests/strategies/test_feddg_ga_strategy.py +++ b/tests/strategies/test_feddg_ga_strategy.py @@ -1,5 +1,5 @@ from 
copy import deepcopy -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple from unittest.mock import Mock import numpy as np @@ -9,11 +9,11 @@ from pytest import approx, raises from fl4health.client_managers.fixed_sampling_client_manager import FixedSamplingClientManager -from fl4health.strategies.feddg_ga_strategy import FairnessMetricType, FedDgGaStrategy +from fl4health.strategies.feddg_ga import FairnessMetricType, FedDgGaStrategy from tests.test_utils.custom_client_proxy import CustomClientProxy -def test_configure_fit_success() -> None: +def test_configure_fit_and_evaluate_success() -> None: fixed_sampling_client_manager = _apply_mocks_to_client_manager(FixedSamplingClientManager()) test_n_server_rounds = 3 @@ -21,9 +21,16 @@ def on_fit_config_fn(server_round: int) -> Dict[str, Scalar]: return { "n_server_rounds": test_n_server_rounds, "evaluate_after_fit": True, + "pack_losses_with_val_metrics": True, } - strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn) + def on_evaluate_config_fn(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": test_n_server_rounds, + "pack_losses_with_val_metrics": True, + } + + strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn, on_evaluate_config_fn=on_evaluate_config_fn) assert strategy.num_rounds is None try: @@ -50,6 +57,7 @@ def on_fit_config_fn(server_round: int) -> Dict[str, Scalar]: return { "n_server_rounds": 2, "evaluate_after_fit": True, + "pack_losses_with_val_metrics": True, } strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn) @@ -61,6 +69,7 @@ def on_fit_config_fn_1(server_round: int) -> Dict[str, Scalar]: return { "foo": 123, "evaluate_after_fit": True, + "pack_losses_with_val_metrics": True, } strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn_1) @@ -74,6 +83,7 @@ def on_fit_config_fn_2(server_round: int) -> Dict[str, Scalar]: return { "n_server_rounds": 1.1, "evaluate_after_fit": True, + "pack_losses_with_val_metrics": True, } 
strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn_2) @@ -86,6 +96,7 @@ def on_fit_config_fn_2(server_round: int) -> Dict[str, Scalar]: def on_fit_config_fn_3(server_round: int) -> Dict[str, Scalar]: return { "n_server_rounds": 2, + "pack_losses_with_val_metrics": True, } strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn_3) @@ -97,12 +108,78 @@ def on_fit_config_fn_4(server_round: int) -> Dict[str, Scalar]: return { "n_server_rounds": 2, "evaluate_after_fit": False, + "pack_losses_with_val_metrics": True, } strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn_4) with raises(AssertionError): strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + # Fails with pack_losses_with_val_metrics not being there + def on_fit_config_fn_5(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 2, + "evaluate_after_fit": True, + } + + strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn_5) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with pack_losses_with_val_metrics not being True + def on_fit_config_fn_6(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 2, + "evaluate_after_fit": True, + "pack_losses_with_val_metrics": False, + } + + strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn_6) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + +def test_configure_evaluate_fail() -> None: + fixed_sampling_client_manager = _apply_mocks_to_client_manager(FixedSamplingClientManager()) + simple_client_manager = _apply_mocks_to_client_manager(SimpleClientManager()) + + # Fails with no evaluate fit + strategy = FedDgGaStrategy() + with raises(AssertionError): + strategy.configure_evaluate(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with bad client manager type + def on_evaluate_config_fn(server_round: int) -> Dict[str, 
Scalar]: + return { + "n_server_rounds": 2, + "pack_losses_with_val_metrics": True, + } + + strategy = FedDgGaStrategy(on_evaluate_config_fn=on_evaluate_config_fn) + with raises(AssertionError): + strategy.configure_evaluate(1, Parameters([], ""), simple_client_manager) + + # Fail with no pack_losses_with_val_metrics + def on_evaluate_config_fn_1(server_round: int) -> Dict[str, Scalar]: + return { + "foo": 123, + } + + strategy = FedDgGaStrategy(on_evaluate_config_fn=on_evaluate_config_fn_1) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with pack_losses_with_val_metrics not being True + def on_fit_config_fn_2(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 1.1, + "pack_losses_with_val_metrics": False, + } + + strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn_2) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + def test_aggregate_fit_and_aggregate_evaluate() -> None: test_fit_results, test_eval_results = _make_test_data() @@ -112,7 +189,6 @@ def test_aggregate_fit_and_aggregate_evaluate() -> None: test_fit_metrics_2 = test_fit_results[1][1].metrics test_eval_metrics_1 = test_eval_results[0][1].metrics test_eval_metrics_2 = test_eval_results[1][1].metrics - test_val_loss_key = FairnessMetricType.LOSS.value test_initial_adjustment_weight = 1.0 / 3.0 strategy = FedDgGaStrategy() @@ -135,16 +211,17 @@ def test_aggregate_fit_and_aggregate_evaluate() -> None: assert parameters_array == [approx(1.0, abs=0.0005), approx(1.0666, abs=0.0005)] # test evaluate fit - _, _ = strategy.aggregate_evaluate(2, deepcopy(test_eval_results), []) + loss_aggregated, _ = strategy.aggregate_evaluate(2, deepcopy(test_eval_results), []) assert strategy.evaluation_metrics == { - test_cid_1: {**test_eval_metrics_1, test_val_loss_key: test_eval_results[0][1].loss}, - test_cid_2: {**test_eval_metrics_2, test_val_loss_key: 
test_eval_results[1][1].loss}, + test_cid_1: {**test_eval_metrics_1}, + test_cid_2: {**test_eval_metrics_2}, } assert strategy.adjustment_weights == { test_cid_1: approx(0.2999, abs=0.0005), test_cid_2: approx(0.7000, abs=0.0005), } + assert approx(loss_aggregated, abs=1e-6) == 1.7 def test_weight_and_aggregate_results_with_default_weights() -> None: @@ -270,10 +347,10 @@ def _apply_mocks_to_client_manager(client_manager: ClientManager) -> ClientManag def _make_test_data() -> Tuple[List[Tuple[ClientProxy, FitRes]], List[Tuple[ClientProxy, EvaluateRes]]]: test_val_loss_key = FairnessMetricType.LOSS.value - test_fit_metrics_1: Dict[str, Union[bool, bytes, float, int, str]] = {test_val_loss_key: 1.0} - test_fit_metrics_2: Dict[str, Union[bool, bytes, float, int, str]] = {test_val_loss_key: 2.0} - test_eval_metrics_1: Dict[str, Union[bool, bytes, float, int, str]] = {"metric-1": 1.0} - test_eval_metrics_2: Dict[str, Union[bool, bytes, float, int, str]] = {"metric-2": 2.0} + test_fit_metrics_1: Dict[str, Scalar] = {test_val_loss_key: 1.0} + test_fit_metrics_2: Dict[str, Scalar] = {test_val_loss_key: 2.0} + test_eval_metrics_1: Dict[str, Scalar] = {"metric-1": 1.0, test_val_loss_key: 1.2} + test_eval_metrics_2: Dict[str, Scalar] = {"metric-2": 2.0, test_val_loss_key: 2.2} test_parameters_1 = ndarrays_to_parameters([np.array([1.0, 1.1])]) test_parameters_2 = ndarrays_to_parameters([np.array([2.0, 2.1])]) test_fit_results = [ diff --git a/tests/strategies/test_feddg_ga_with_adapt_constraint.py b/tests/strategies/test_feddg_ga_with_adapt_constraint.py new file mode 100644 index 000000000..e7300d963 --- /dev/null +++ b/tests/strategies/test_feddg_ga_with_adapt_constraint.py @@ -0,0 +1,411 @@ +from copy import deepcopy +from typing import Dict, List, Tuple +from unittest.mock import Mock + +import numpy as np +from flwr.common.parameter import ndarrays_to_parameters, parameters_to_ndarrays +from flwr.common.typing import Code, EvaluateRes, FitRes, Parameters, Scalar, Status 
+from flwr.server.client_manager import ClientManager, ClientProxy, SimpleClientManager +from pytest import approx, raises + +from fl4health.client_managers.fixed_sampling_client_manager import FixedSamplingClientManager +from fl4health.strategies.feddg_ga import FairnessMetricType +from fl4health.strategies.feddg_ga_with_adaptive_constraint import FedDgGaAdaptiveConstraint +from tests.test_utils.custom_client_proxy import CustomClientProxy + +INITIAL_PARAMETERS = ndarrays_to_parameters([np.array([0.0, 0.0])]) + + +def test_configure_fit_and_evaluate_success() -> None: + fixed_sampling_client_manager = _apply_mocks_to_client_manager(FixedSamplingClientManager()) + test_n_server_rounds = 3 + + def on_fit_config_fn(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": test_n_server_rounds, + "evaluate_after_fit": True, + "pack_losses_with_val_metrics": True, + } + + def on_evaluate_config_fn(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": test_n_server_rounds, + "pack_losses_with_val_metrics": True, + } + + strategy = FedDgGaAdaptiveConstraint( + initial_parameters=INITIAL_PARAMETERS, + on_fit_config_fn=on_fit_config_fn, + on_evaluate_config_fn=on_evaluate_config_fn, + ) + assert strategy.num_rounds is None + + try: + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + except Exception as e: + assert False, f"initialize_parameters threw an exception: {e}" + + assert strategy.num_rounds == test_n_server_rounds + assert strategy.initial_adjustment_weight == 1.0 / fixed_sampling_client_manager.num_available() + fixed_sampling_client_manager.reset_sample.assert_called_once() # type: ignore + + +def test_configure_fit_fail() -> None: + fixed_sampling_client_manager = _apply_mocks_to_client_manager(FixedSamplingClientManager()) + simple_client_manager = _apply_mocks_to_client_manager(SimpleClientManager()) + + # Fails with no configure fit + strategy = 
FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with bad client manager type + def on_fit_config_fn(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 2, + "evaluate_after_fit": True, + "pack_losses_with_val_metrics": True, + } + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS, on_fit_config_fn=on_fit_config_fn) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), simple_client_manager) + + # Fail with no n_server_rounds + def on_fit_config_fn_1(server_round: int) -> Dict[str, Scalar]: + return { + "foo": 123, + "evaluate_after_fit": True, + "pack_losses_with_val_metrics": True, + } + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS, on_fit_config_fn=on_fit_config_fn_1) + assert strategy.num_rounds is None + + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with n_server_rounds not being an integer + def on_fit_config_fn_2(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 1.1, + "evaluate_after_fit": True, + "pack_losses_with_val_metrics": True, + } + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS, on_fit_config_fn=on_fit_config_fn_2) + assert strategy.num_rounds is None + + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with evaluate_after_fit not being set + def on_fit_config_fn_3(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 2, + "pack_losses_with_val_metrics": True, + } + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS, on_fit_config_fn=on_fit_config_fn_3) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # 
Fails with evaluate_after_fit not being True + def on_fit_config_fn_4(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 2, + "evaluate_after_fit": False, + "pack_losses_with_val_metrics": True, + } + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS, on_fit_config_fn=on_fit_config_fn_4) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with pack_losses_with_val_metrics not being there + def on_fit_config_fn_5(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 2, + "evaluate_after_fit": True, + } + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS, on_fit_config_fn=on_fit_config_fn_5) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with pack_losses_with_val_metrics not being True + def on_fit_config_fn_6(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 2, + "evaluate_after_fit": True, + "pack_losses_with_val_metrics": False, + } + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS, on_fit_config_fn=on_fit_config_fn_6) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + +def test_configure_evaluate_fail() -> None: + fixed_sampling_client_manager = _apply_mocks_to_client_manager(FixedSamplingClientManager()) + simple_client_manager = _apply_mocks_to_client_manager(SimpleClientManager()) + + # Fails with no evaluate fit + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS) + with raises(AssertionError): + strategy.configure_evaluate(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with bad client manager type + def on_evaluate_config_fn(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 2, + "pack_losses_with_val_metrics": True, + } + + strategy = 
FedDgGaAdaptiveConstraint( + initial_parameters=INITIAL_PARAMETERS, on_evaluate_config_fn=on_evaluate_config_fn + ) + with raises(AssertionError): + strategy.configure_evaluate(1, Parameters([], ""), simple_client_manager) + + # Fail with no pack_losses_with_val_metrics + def on_evaluate_config_fn_1(server_round: int) -> Dict[str, Scalar]: + return { + "foo": 123, + } + + strategy = FedDgGaAdaptiveConstraint( + initial_parameters=INITIAL_PARAMETERS, on_evaluate_config_fn=on_evaluate_config_fn_1 + ) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with pack_losses_with_val_metrics not being True + def on_fit_config_fn_2(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 1.1, + "pack_losses_with_val_metrics": False, + } + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS, on_fit_config_fn=on_fit_config_fn_2) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + +def test_aggregate_fit_and_aggregate_evaluate() -> None: + test_fit_results, test_eval_results = _make_test_data() + test_cid_1 = test_fit_results[0][0].cid + test_cid_2 = test_fit_results[1][0].cid + test_fit_metrics_1 = test_fit_results[0][1].metrics + test_fit_metrics_2 = test_fit_results[1][1].metrics + test_eval_metrics_1 = test_eval_results[0][1].metrics + test_eval_metrics_2 = test_eval_results[1][1].metrics + test_initial_adjustment_weight = 1.0 / 3.0 + + strategy = FedDgGaAdaptiveConstraint( + initial_parameters=INITIAL_PARAMETERS, initial_loss_weight=1.0, adapt_loss_weight=True, loss_weight_patience=1 + ) + strategy.num_rounds = 3 + strategy.initial_adjustment_weight = test_initial_adjustment_weight + + # test aggregate fit + parameters_aggregated, _ = strategy.aggregate_fit(2, deepcopy(test_fit_results), []) + + # make sure the loss has been aggregated and stored and the loss weight adjusted + assert
strategy.previous_loss == 2.0 + assert strategy.loss_weight == 0.9 + + assert strategy.train_metrics == { + test_cid_1: test_fit_metrics_1, + test_cid_2: test_fit_metrics_2, + } + assert strategy.adjustment_weights == { + test_cid_1: test_initial_adjustment_weight, + test_cid_2: test_initial_adjustment_weight, + } + assert parameters_aggregated is not None + parameters_array = parameters_to_ndarrays(parameters_aggregated)[0].tolist() + assert parameters_array == [approx(1.0, abs=0.0005), approx(1.0666, abs=0.0005)] + + # test evaluate fit + loss_aggregated, _ = strategy.aggregate_evaluate(2, deepcopy(test_eval_results), []) + + assert strategy.evaluation_metrics == { + test_cid_1: {**test_eval_metrics_1}, + test_cid_2: {**test_eval_metrics_2}, + } + assert strategy.adjustment_weights == { + test_cid_1: approx(0.2999, abs=0.0005), + test_cid_2: approx(0.7000, abs=0.0005), + } + assert approx(loss_aggregated, abs=1e-6) == 1.7 + + +def test_weight_and_aggregate_results_with_default_weights() -> None: + test_fit_results, _ = _make_test_data() + test_cid_1 = test_fit_results[0][0].cid + test_cid_2 = test_fit_results[1][0].cid + test_initial_adjustment_weight = 1.0 / 3.0 + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS) + strategy.initial_adjustment_weight = test_initial_adjustment_weight + strategy._unpack_weights_and_losses(test_fit_results) + aggregated_results = strategy.weight_and_aggregate_results(test_fit_results) + + assert strategy.adjustment_weights == { + test_cid_1: test_initial_adjustment_weight, + test_cid_2: test_initial_adjustment_weight, + } + assert aggregated_results[0].tolist() == [approx(1.0, abs=0.0005), approx(1.0666, abs=0.0005)] + + +def test_weight_and_aggregate_results_with_existing_weights() -> None: + test_fit_results, _ = _make_test_data() + test_cid_1 = test_fit_results[0][0].cid + test_cid_2 = test_fit_results[1][0].cid + test_adjustment_weights = {test_cid_1: 0.21, test_cid_2: 0.76} + + strategy = 
FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS) + strategy.adjustment_weights = deepcopy(test_adjustment_weights) + strategy._unpack_weights_and_losses(test_fit_results) + aggregated_results = strategy.weight_and_aggregate_results(test_fit_results) + + assert strategy.adjustment_weights == test_adjustment_weights + assert aggregated_results[0].tolist() == [approx(1.73, abs=0.0005), approx(1.8270, abs=0.0005)] + + +def test_update_weights_by_ga() -> None: + test_cids = ["1", "2"] + test_val_loss_key = FairnessMetricType.LOSS.value + test_initial_adjustment_weight = 1.0 / 3.0 + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS) + strategy.num_rounds = 3 + strategy.initial_adjustment_weight = test_initial_adjustment_weight + strategy.train_metrics = { + test_cids[0]: {test_val_loss_key: 0.5467}, + test_cids[1]: {test_val_loss_key: 0.5432}, + } + strategy.evaluation_metrics = { + test_cids[0]: {test_val_loss_key: 0.3556}, + test_cids[1]: {test_val_loss_key: 0.7654}, + } + strategy.adjustment_weights = { + test_cids[0]: test_initial_adjustment_weight, + test_cids[1]: test_initial_adjustment_weight, + } + + strategy.update_weights_by_ga(2, test_cids) + + assert strategy.adjustment_weights == { + test_cids[0]: approx(0.2999, abs=0.0005), + test_cids[1]: approx(0.7000, abs=0.0005), + } + + +def test_update_weights_by_ga_with_same_metrics() -> None: + test_cids = ["1", "2"] + test_val_loss_key = FairnessMetricType.LOSS.value + test_initial_adjustment_weight = 1.0 / 3.0 + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS) + strategy.num_rounds = 3 + strategy.initial_adjustment_weight = test_initial_adjustment_weight + strategy.train_metrics = { + test_cids[0]: {test_val_loss_key: 0.5467}, + test_cids[1]: {test_val_loss_key: 0.5432}, + } + strategy.evaluation_metrics = { + test_cids[0]: {test_val_loss_key: 0.5467}, + test_cids[1]: {test_val_loss_key: 0.5432}, + } + strategy.adjustment_weights = { + 
test_cids[0]: test_initial_adjustment_weight, + test_cids[1]: test_initial_adjustment_weight, + } + + strategy.update_weights_by_ga(2, test_cids) + + assert strategy.adjustment_weights == {test_cids[0]: 0.5, test_cids[1]: 0.5} + + +def test_get_current_weight_step_size() -> None: + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS) + + with raises(AssertionError): + strategy.get_current_weight_step_size(2) + + strategy.num_rounds = 3 + result_step_size = strategy.get_current_weight_step_size(1) + assert result_step_size == approx(0.2000, abs=0.0005) + result_step_size = strategy.get_current_weight_step_size(2) + assert result_step_size == approx(0.1333, abs=0.0005) + result_step_size = strategy.get_current_weight_step_size(3) + assert result_step_size == approx(0.0666, abs=0.0005) + + strategy.num_rounds = 10 + result_step_size = strategy.get_current_weight_step_size(6) + assert result_step_size == approx(0.1000, abs=0.0005) + + strategy.num_rounds = 10 + strategy.weight_step_size = 0.5 + result_step_size = strategy.get_current_weight_step_size(6) + assert result_step_size == approx(0.2500, abs=0.0005) + + +def test_unpack_weights_and_losses() -> None: + test_fit_results, _ = _make_test_data() + # make sure the results are of length 2 (one for the weights, one for the loss) + assert len(test_fit_results[0][1].parameters.tensors) == 2 + assert len(test_fit_results[1][1].parameters.tensors) == 2 + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS) + train_losses_and_counts = strategy._unpack_weights_and_losses(test_fit_results) + + # Assert that the fit results have been modified in place and properly + assert len(test_fit_results) == 2 + test_ndarrays_1 = parameters_to_ndarrays(test_fit_results[0][1].parameters) + test_ndarrays_2 = parameters_to_ndarrays(test_fit_results[1][1].parameters) + target_ndarray_1 = np.array([1.0, 1.1]) + target_ndarray_2 = np.array([2.0, 2.1]) + # length should be 1, since we've unpacked 
the loss arrays + assert len(test_ndarrays_1) == 1 + assert len(test_ndarrays_2) == 1 + + assert np.allclose(test_ndarrays_1[0], target_ndarray_1, rtol=0.0, atol=1e-6) + assert np.allclose(test_ndarrays_2[0], target_ndarray_2, rtol=0.0, atol=1e-6) + + # Make sure that the losses have properly been extracted and stored. + assert train_losses_and_counts[0][1] == 1.5 + assert train_losses_and_counts[1][1] == 2.5 + + +def _apply_mocks_to_client_manager(client_manager: ClientManager) -> ClientManager: + client_proxy_1 = CustomClientProxy("1") + client_proxy_2 = CustomClientProxy("2") + client_manager.register(client_proxy_1) + client_manager.register(client_proxy_2) + client_manager.sample = Mock() # type: ignore + client_manager.sample.return_value = [client_proxy_1, client_proxy_2] + client_manager.reset_sample = Mock() # type: ignore + return client_manager + + +def _make_test_data() -> Tuple[List[Tuple[ClientProxy, FitRes]], List[Tuple[ClientProxy, EvaluateRes]]]: + test_val_loss_key = FairnessMetricType.LOSS.value + test_fit_metrics_1: Dict[str, Scalar] = {test_val_loss_key: 1.0} + test_fit_metrics_2: Dict[str, Scalar] = {test_val_loss_key: 2.0} + test_eval_metrics_1: Dict[str, Scalar] = {"metric-1": 1.0, test_val_loss_key: 1.2} + test_eval_metrics_2: Dict[str, Scalar] = {"metric-2": 2.0, test_val_loss_key: 2.2} + test_parameters_1 = ndarrays_to_parameters([np.array([1.0, 1.1]), np.array(1.5)]) + test_parameters_2 = ndarrays_to_parameters([np.array([2.0, 2.1]), np.array(2.5)]) + test_fit_results = [ + (CustomClientProxy("1"), FitRes(Status(Code.OK, ""), test_parameters_1, 1, test_fit_metrics_1)), + (CustomClientProxy("2"), FitRes(Status(Code.OK, ""), test_parameters_2, 1, test_fit_metrics_2)), + ] + test_evaluate_results = [ + (CustomClientProxy("1"), EvaluateRes(Status(Code.OK, ""), 1.2, 1, test_eval_metrics_1)), + (CustomClientProxy("2"), EvaluateRes(Status(Code.OK, ""), 2.2, 1, test_eval_metrics_2)), + ] + + return test_fit_results, test_evaluate_results # type: 
ignore From 374fc17ffd1aa6d05fdf14b57a2e3aa486cfcc67 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Wed, 9 Oct 2024 09:54:52 -0400 Subject: [PATCH 05/19] Adding in more experiments for the fed-dgga strategy applied to more complex FL strategies. Also make a few small tweaks to adaptive_pfl. --- examples/feddg_ga_example/server.py | 3 + research/cifar10/adaptive_pfl/ditto/client.py | 2 +- research/cifar10/adaptive_pfl/ditto/server.py | 32 +-- .../cifar10/adaptive_pfl/fedprox/client.py | 2 +- .../cifar10/adaptive_pfl/fedprox/server.py | 2 +- .../adaptive_pfl/fenda_ditto/client.py | 2 +- .../fenda_ditto/run_fold_experiment.slrm | 40 ++-- .../adaptive_pfl/fenda_ditto/server.py | 34 +--- research/cifar10/adaptive_pfl/mrmtl/client.py | 2 +- research/cifar10/adaptive_pfl/mrmtl/server.py | 2 +- research/cifar10/fed_dgga_pfl/ditto/client.py | 164 +++++++++++++++ .../cifar10/fed_dgga_pfl/ditto/config.yaml | 11 + .../ditto/run_fold_experiment.slrm | 169 ++++++++++++++++ .../fed_dgga_pfl/ditto/run_hp_sweep.sh | 73 +++++++ research/cifar10/fed_dgga_pfl/ditto/server.py | 141 +++++++++++++ research/cifar10/fed_dgga_pfl/fenda/client.py | 162 +++++++++++++++ .../cifar10/fed_dgga_pfl/fenda/config.yaml | 11 + .../fenda/run_fold_experiment.slrm | 168 ++++++++++++++++ .../fed_dgga_pfl/fenda/run_hp_sweep.sh | 70 +++++++ research/cifar10/fed_dgga_pfl/fenda/server.py | 133 ++++++++++++ .../fed_dgga_pfl/fenda_ditto/client.py | 179 +++++++++++++++++ .../fed_dgga_pfl/fenda_ditto/config.yaml | 11 + .../fenda_ditto/run_fold_experiment.slrm | 189 ++++++++++++++++++ .../fed_dgga_pfl/fenda_ditto/run_hp_sweep.sh | 81 ++++++++ .../fed_dgga_pfl/fenda_ditto/server.py | 140 +++++++++++++ 25 files changed, 1740 insertions(+), 83 deletions(-) create mode 100644 research/cifar10/fed_dgga_pfl/ditto/client.py create mode 100644 research/cifar10/fed_dgga_pfl/ditto/config.yaml create mode 100644 research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm 
create mode 100755 research/cifar10/fed_dgga_pfl/ditto/run_hp_sweep.sh create mode 100644 research/cifar10/fed_dgga_pfl/ditto/server.py create mode 100644 research/cifar10/fed_dgga_pfl/fenda/client.py create mode 100644 research/cifar10/fed_dgga_pfl/fenda/config.yaml create mode 100644 research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm create mode 100755 research/cifar10/fed_dgga_pfl/fenda/run_hp_sweep.sh create mode 100644 research/cifar10/fed_dgga_pfl/fenda/server.py create mode 100644 research/cifar10/fed_dgga_pfl/fenda_ditto/client.py create mode 100644 research/cifar10/fed_dgga_pfl/fenda_ditto/config.yaml create mode 100644 research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm create mode 100755 research/cifar10/fed_dgga_pfl/fenda_ditto/run_hp_sweep.sh create mode 100644 research/cifar10/fed_dgga_pfl/fenda_ditto/server.py diff --git a/examples/feddg_ga_example/server.py b/examples/feddg_ga_example/server.py index d78226247..b163d24f5 100644 --- a/examples/feddg_ga_example/server.py +++ b/examples/feddg_ga_example/server.py @@ -24,6 +24,7 @@ def fit_config( local_epochs: Optional[int] = None, local_steps: Optional[int] = None, evaluate_after_fit: bool = False, + pack_losses_with_val_metrics: bool = False, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -31,6 +32,7 @@ def fit_config( "batch_size": batch_size, "n_server_rounds": n_server_rounds, "evaluate_after_fit": evaluate_after_fit, + "pack_losses_with_val_metrics": pack_losses_with_val_metrics, } @@ -43,6 +45,7 @@ def main(config: Dict[str, Any]) -> None: local_epochs=config.get("local_epochs"), local_steps=config.get("local_steps"), evaluate_after_fit=config.get("evaluate_after_fit", False), + pack_losses_with_val_metrics=config.get("pack_losses_with_val_metrics", False), ) initial_model = ApflModule(MnistNetWithBnAndFrozen()) diff --git a/research/cifar10/adaptive_pfl/ditto/client.py b/research/cifar10/adaptive_pfl/ditto/client.py index 
f0a916ed4..0800f2913 100644 --- a/research/cifar10/adaptive_pfl/ditto/client.py +++ b/research/cifar10/adaptive_pfl/ditto/client.py @@ -65,7 +65,7 @@ def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: return {"global": global_optimizer, "local": local_optimizer} def get_model(self, config: Config) -> nn.Module: - return ConvNet(in_channels=3, use_bn=False).to(self.device) + return ConvNet(in_channels=3, use_bn=False, dropout=0.1).to(self.device) if __name__ == "__main__": diff --git a/research/cifar10/adaptive_pfl/ditto/server.py b/research/cifar10/adaptive_pfl/ditto/server.py index 8fdeecaaf..e3500627b 100644 --- a/research/cifar10/adaptive_pfl/ditto/server.py +++ b/research/cifar10/adaptive_pfl/ditto/server.py @@ -1,13 +1,12 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict, Optional +from typing import Any, Dict import flwr as fl from flwr.common.logger import log from flwr.common.typing import Config -from flwr.server.client_manager import ClientManager, SimpleClientManager -from flwr.server.strategy import Strategy +from flwr.server.client_manager import SimpleClientManager from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint from fl4health.utils.config import load_config @@ -18,29 +17,6 @@ from research.cifar10.personal_server import PersonalServer -class PersonalDittoServer(PersonalServer): - """ - The PersonalServer class is used for FL approaches that only have a sense of a PERSONAL model that is checkpointed - and valid only on the client size of the FL training framework. FL approaches like APFL and FENDA fall under this - category. Each client will have its own model that is specific to its own training. Personal models may have - shared components but the full model is specific to each client. 
This is distinct from the - FlServerWithCheckpointing class which has a sense of a GLOBAL model checkpointed on the server-side that is - shared by all clients. - """ - - def __init__( - self, - client_manager: ClientManager, - strategy: Optional[Strategy] = None, - ) -> None: - assert isinstance( - strategy, FedAvgWithAdaptiveConstraint - ), "Strategy must be of base type FedAvgWithAdaptiveConstraint" - # Personal approaches don't train a "server" model. Rather, each client trains a client specific model with - # some globally shared weights. So we don't checkpoint a global model - super().__init__(client_manager, strategy) - - def fit_config( batch_size: int, local_epochs: int, @@ -69,7 +45,7 @@ def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_wei client_manager = SimpleClientManager() # Initializing the model on the server side - model = ConvNet(in_channels=3, use_bn=False) + model = ConvNet(in_channels=3, use_bn=False, dropout=0.1) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvgWithAdaptiveConstraint( min_fit_clients=config["n_clients"], @@ -86,7 +62,7 @@ def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_wei adapt_loss_weight=adapt_loss_weight, ) - server = PersonalDittoServer(client_manager=client_manager, strategy=strategy) + server = PersonalServer(client_manager=client_manager, strategy=strategy) fl.server.start_server( server=server, diff --git a/research/cifar10/adaptive_pfl/fedprox/client.py b/research/cifar10/adaptive_pfl/fedprox/client.py index ed56117a0..394134876 100644 --- a/research/cifar10/adaptive_pfl/fedprox/client.py +++ b/research/cifar10/adaptive_pfl/fedprox/client.py @@ -63,7 +63,7 @@ def get_optimizer(self, config: Config) -> Optimizer: return torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) def get_model(self, config: Config) -> nn.Module: - return ConvNet(in_channels=3, use_bn=False).to(self.device) + return 
ConvNet(in_channels=3, use_bn=False, dropout=0.1).to(self.device) if __name__ == "__main__": diff --git a/research/cifar10/adaptive_pfl/fedprox/server.py b/research/cifar10/adaptive_pfl/fedprox/server.py index 3206d4ef2..0e75e0bbc 100644 --- a/research/cifar10/adaptive_pfl/fedprox/server.py +++ b/research/cifar10/adaptive_pfl/fedprox/server.py @@ -61,7 +61,7 @@ def main( client_manager = SimpleClientManager() # Initializing the model on the server side - model = ConvNet(in_channels=3, use_bn=False) + model = ConvNet(in_channels=3, use_bn=False, dropout=0.1) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvgWithAdaptiveConstraint( min_fit_clients=config["n_clients"], diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/client.py b/research/cifar10/adaptive_pfl/fenda_ditto/client.py index 4fa23e727..15dc527e1 100644 --- a/research/cifar10/adaptive_pfl/fenda_ditto/client.py +++ b/research/cifar10/adaptive_pfl/fenda_ditto/client.py @@ -71,7 +71,7 @@ def get_model(self, config: Config) -> FendaModel: return ConvNetFendaModel(in_channels=3, use_bn=False).to(self.device) def get_global_model(self, config: Config) -> SequentiallySplitModel: - return ConvNetFendaDittoGlobalModel(in_channels=3, use_bn=False).to(self.device) + return ConvNetFendaDittoGlobalModel(in_channels=3, use_bn=False, dropout=0.1).to(self.device) if __name__ == "__main__": diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm index 68c6589ae..e5d58086f 100644 --- a/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm @@ -160,29 +160,29 @@ do CLIENT_LOG_PATH="${RUN_DIR}${CLIENT_NAME}.out" echo "${CLIENT_NAME} logging at: ${CLIENT_LOG_PATH}" - if [[ ${ADAPT} == "TRUE" ]]; then + if [[ ${FREEZE} == "TRUE" ]]; then nohup python -m research.cifar10.adaptive_pfl.fenda_ditto.client \ 
- --artifact_dir ${ARTIFACT_DIR} \ - --dataset_dir ${DATASET_DIR} \ - --run_name ${RUN_NAME} \ - --client_number ${c} \ - --learning_rate ${CLIENT_LR} \ - --server_address ${SERVER_ADDRESS} \ - --seed ${SEED} \ - --beta ${CLIENT_BETA} \ - --freeze_global_extractor \ - > ${CLIENT_LOG_PATH} 2>&1 & + --artifact_dir ${ARTIFACT_DIR} \ + --dataset_dir ${DATASET_DIR} \ + --run_name ${RUN_NAME} \ + --client_number ${c} \ + --learning_rate ${CLIENT_LR} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --beta ${CLIENT_BETA} \ + --freeze_global_extractor \ + > ${CLIENT_LOG_PATH} 2>&1 & else nohup python -m research.cifar10.adaptive_pfl.fenda_ditto.client \ - --artifact_dir ${ARTIFACT_DIR} \ - --dataset_dir ${DATASET_DIR} \ - --run_name ${RUN_NAME} \ - --client_number ${c} \ - --learning_rate ${CLIENT_LR} \ - --server_address ${SERVER_ADDRESS} \ - --seed ${SEED} \ - --beta ${CLIENT_BETA} \ - > ${CLIENT_LOG_PATH} 2>&1 & + --artifact_dir ${ARTIFACT_DIR} \ + --dataset_dir ${DATASET_DIR} \ + --run_name ${RUN_NAME} \ + --client_number ${c} \ + --learning_rate ${CLIENT_LR} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --beta ${CLIENT_BETA} \ + > ${CLIENT_LOG_PATH} 2>&1 & fi done diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/server.py b/research/cifar10/adaptive_pfl/fenda_ditto/server.py index 767583c80..3a45d46cd 100644 --- a/research/cifar10/adaptive_pfl/fenda_ditto/server.py +++ b/research/cifar10/adaptive_pfl/fenda_ditto/server.py @@ -1,46 +1,22 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict, Optional +from typing import Any, Dict import flwr as fl from flwr.common.logger import log from flwr.common.typing import Config -from flwr.server.client_manager import ClientManager, SimpleClientManager -from flwr.server.strategy import Strategy +from flwr.server.client_manager import SimpleClientManager from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint 
from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters from fl4health.utils.random import set_all_random_seeds -from research.cifar10.model import ConvNet +from research.cifar10.model import ConvNetFendaDittoGlobalModel from research.cifar10.personal_server import PersonalServer -class PersonalFendaDittoServer(PersonalServer): - """ - The PersonalServer class is used for FL approaches that only have a sense of a PERSONAL model that is checkpointed - and valid only on the client size of the FL training framework. FL approaches like APFL and FENDA fall under this - category. Each client will have its own model that is specific to its own training. Personal models may have - shared components but the full model is specific to each client. This is distinct from the - FlServerWithCheckpointing class which has a sense of a GLOBAL model checkpointed on the server-side that is - shared by all clients. - """ - - def __init__( - self, - client_manager: ClientManager, - strategy: Optional[Strategy] = None, - ) -> None: - assert isinstance( - strategy, FedAvgWithAdaptiveConstraint - ), "Strategy must be of base type FedAvgWithAdaptiveConstraint" - # Personal approaches don't train a "server" model. Rather, each client trains a client specific model with - # some globally shared weights. 
So we don't checkpoint a global model - super().__init__(client_manager, strategy) - - def fit_config( batch_size: int, local_epochs: int, @@ -69,7 +45,7 @@ def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_wei client_manager = SimpleClientManager() # Initializing the model on the server side - model = ConvNet(in_channels=3, use_bn=False) + model = ConvNetFendaDittoGlobalModel(in_channels=3, use_bn=False, dropout=0.1) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvgWithAdaptiveConstraint( min_fit_clients=config["n_clients"], @@ -86,7 +62,7 @@ def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_wei adapt_loss_weight=adapt_loss_weight, ) - server = PersonalFendaDittoServer(client_manager=client_manager, strategy=strategy) + server = PersonalServer(client_manager=client_manager, strategy=strategy) fl.server.start_server( server=server, diff --git a/research/cifar10/adaptive_pfl/mrmtl/client.py b/research/cifar10/adaptive_pfl/mrmtl/client.py index 84ff0c9ec..47dee7393 100644 --- a/research/cifar10/adaptive_pfl/mrmtl/client.py +++ b/research/cifar10/adaptive_pfl/mrmtl/client.py @@ -63,7 +63,7 @@ def get_optimizer(self, config: Config) -> Optimizer: return torch.optim.SGD(self.model.parameters(), lr=self.learning_rate) def get_model(self, config: Config) -> nn.Module: - return ConvNet(in_channels=3, use_bn=False).to(self.device) + return ConvNet(in_channels=3, use_bn=False, dropout=0.1).to(self.device) if __name__ == "__main__": diff --git a/research/cifar10/adaptive_pfl/mrmtl/server.py b/research/cifar10/adaptive_pfl/mrmtl/server.py index 8ffca7c81..f965aa813 100644 --- a/research/cifar10/adaptive_pfl/mrmtl/server.py +++ b/research/cifar10/adaptive_pfl/mrmtl/server.py @@ -69,7 +69,7 @@ def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_wei client_manager = SimpleClientManager() # Initializing the model on the server side - model = 
ConvNet(in_channels=3, use_bn=False) + model = ConvNet(in_channels=3, use_bn=False, dropout=0.1) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvgWithAdaptiveConstraint( min_fit_clients=config["n_clients"], diff --git a/research/cifar10/fed_dgga_pfl/ditto/client.py b/research/cifar10/fed_dgga_pfl/ditto/client.py new file mode 100644 index 000000000..0800f2913 --- /dev/null +++ b/research/cifar10/fed_dgga_pfl/ditto/client.py @@ -0,0 +1,164 @@ +import argparse +import os +from logging import INFO +from pathlib import Path +from typing import Dict, Optional, Sequence, Tuple + +import flwr as fl +import torch +import torch.nn as nn +from flwr.common.logger import log +from flwr.common.typing import Config +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.clients.ditto_client import DittoClient +from fl4health.utils.config import narrow_dict_type +from fl4health.utils.losses import LossMeterType +from fl4health.utils.metrics import F1, Accuracy, Metric +from fl4health.utils.random import set_all_random_seeds +from research.cifar10.model import ConvNet +from research.cifar10.preprocess import get_preprocessed_data + + +class CifarDittoClient(DittoClient): + def __init__( + self, + data_path: Path, + metrics: Sequence[Metric], + device: torch.device, + client_number: int, + learning_rate: float, + heterogeneity_level: float, + loss_meter_type: LossMeterType = LossMeterType.AVERAGE, + checkpointer: Optional[ClientCheckpointModule] = None, + ) -> None: + super().__init__( + data_path=data_path, + metrics=metrics, + device=device, + loss_meter_type=loss_meter_type, + checkpointer=checkpointer, + ) + self.client_number = client_number + self.heterogeneity_level = 
heterogeneity_level + self.learning_rate: float = learning_rate + + log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") + + def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + batch_size = narrow_dict_type(config, "batch_size", int) + train_loader, val_loader, _ = get_preprocessed_data( + self.data_path, self.client_number, batch_size, self.heterogeneity_level + ) + return train_loader, val_loader + + def get_criterion(self, config: Config) -> _Loss: + return torch.nn.CrossEntropyLoss() + + def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=self.learning_rate) + local_optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) + return {"global": global_optimizer, "local": local_optimizer} + + def get_model(self, config: Config) -> nn.Module: + return ConvNet(in_channels=3, use_bn=False, dropout=0.1).to(self.device) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FL Client Main") + parser.add_argument( + "--artifact_dir", + action="store", + type=str, + help="Path to save client artifacts such as logs and model checkpoints", + required=True, + ) + parser.add_argument( + "--dataset_dir", + action="store", + type=str, + help="Path to the preprocessed Cifar 10 Dataset", + required=True, + ) + parser.add_argument( + "--run_name", + action="store", + help="Name of the run, model checkpoints will be saved under a subfolder with this name", + required=True, + ) + parser.add_argument( + "--server_address", + action="store", + type=str, + help="Server Address for the clients to communicate with the server through", + default="0.0.0.0:8080", + ) + parser.add_argument( + "--client_number", + action="store", + type=int, + help="Number of the client for dataset loading", + required=True, + ) + parser.add_argument( + "--learning_rate", action="store", type=float, help="Learning rate for 
local optimization", default=0.1 + ) + parser.add_argument( + "--seed", + action="store", + type=int, + help="Seed for the random number generators across python, torch, and numpy", + required=False, + ) + parser.add_argument( + "--beta", + action="store", + type=float, + help="Heterogeneity level for the dataset", + required=True, + ) + args = parser.parse_args() + + DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {DEVICE}") + log(INFO, f"Server Address: {args.server_address}") + log(INFO, f"Learning Rate: {args.learning_rate}") + log(INFO, f"Beta: {args.beta}") + + # Set the random seed for reproducibility + set_all_random_seeds(args.seed) + + # Adding extensive checkpointing for the client + checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) + pre_aggregation_best_checkpoint_name = f"pre_aggregation_client_{args.client_number}_best_model.pkl" + pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" + post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" + post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" + checkpointer = ClientCheckpointModule( + pre_aggregation=[ + BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), + LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), + ], + post_aggregation=[ + BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), + LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), + ], + ) + + data_path = Path(args.dataset_dir) + client = CifarDittoClient( + data_path=data_path, + metrics=[Accuracy("accuracy"), F1("F1_Score", average="macro")], + device=DEVICE, + client_number=args.client_number, + learning_rate=args.learning_rate, + heterogeneity_level=args.beta, + checkpointer=checkpointer, + ) + + 
fl.client.start_client(server_address=args.server_address, client=client.to_client()) + client.shutdown() diff --git a/research/cifar10/fed_dgga_pfl/ditto/config.yaml b/research/cifar10/fed_dgga_pfl/ditto/config.yaml new file mode 100644 index 000000000..fb088bf2b --- /dev/null +++ b/research/cifar10/fed_dgga_pfl/ditto/config.yaml @@ -0,0 +1,11 @@ +# Parameters that describe server +n_server_rounds: 50 # The number of rounds to run FL + +# Parameters that describe clients +n_clients: 7 # The number of clients in the FL experiment +local_epochs: 1 # The number of epochs to complete for client +batch_size: 32 # The batch size for client training +# Evaluates model immediately after local training on the validation set (in addition to the training set) +evaluate_after_fit: True +# Packs the measured validation losses with the metrics (required for fed-dgga) +pack_losses_with_val_metrics: True diff --git a/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm b/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm new file mode 100644 index 000000000..5330439ec --- /dev/null +++ b/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm @@ -0,0 +1,169 @@ +#!/bin/bash + +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --mem=32G +#SBATCH --partition=a40 +#SBATCH --qos=m2 +#SBATCH --job-name=fl_five_fold_exp +#SBATCH --output=%j_%x.out +#SBATCH --error=%j_%x.err +#SBATCH --time=4:00:00 + +############################################### +# Usage: +# +# sbatch research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm \ +# path_to_config.yaml \ +# path_to_folder_for_artifacts/ \ +# path_to_folder_for_dataset/ \ +# path_to_desired_venv/ \ +# client_side_learning_rate_value \ +# lambda value \ +# server_address \ +# client_beta \ +# step_size +# +# Example: +# sbatch research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm \ +# research/cifar10/fed_dgga_pfl/ditto/config.yaml \ 
+# research/cifar10/fed_dgga_pfl/ditto/hp_results/ \ +# /datasets/cifar10 \ +# /h/demerson/vector_repositories/fl4health_env/ \ +# 0.0001 \ +# 0.01 \ +# 0.0.0.0:8080 \ +# 0.1 \ +# 0.2 +# Notes: +# 1) The sbatch command above should be run from the top level directory of the repository. +# 2) This example runs ditto. As such the data paths and python launch commands are hardcoded. If you want to change +# the example you run, you need to explicitly modify the code below. +# 3) The parent of the artifact directory needs to ALREADY EXIST. The script only creates the artifact and run folders. +############################################### + +# Note: +# ntasks: Total number of processes to use across world +# ntasks-per-node: How many processes each node should create + +# Set NCCL options +# export NCCL_DEBUG=INFO +# NCCL backend to communicate between GPU workers is not provided in vector's cluster. +# Disable this option in slurm. +export NCCL_IB_DISABLE=1 + +if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ + [[ "${SLURM_JOB_PARTITION}" == "rtx6000" ]]; then + echo export NCCL_SOCKET_IFNAME=bond0 on "${SLURM_JOB_PARTITION}" + export NCCL_SOCKET_IFNAME=bond0 +fi + +# Process Inputs + +SERVER_CONFIG_PATH=$1 +ARTIFACT_DIR=$2 +DATASET_DIR=$3 +VENV_PATH=$4 +CLIENT_LR=$5 +LAM_VALUE=$6 +SERVER_ADDRESS=$7 +CLIENT_BETA=$8 +STEP_SIZE=$9 + +# Create the artifact directory +mkdir "${ARTIFACT_DIR}" + +RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) +SEEDS=(2021 2022 2023 2024 2025) + +echo "Python Venv Path: ${VENV_PATH}" + +echo "World size: ${SLURM_NTASKS}" +echo "Number of nodes: ${SLURM_NNODES}" +NUM_GPUs=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) +echo "GPUs per node: ${NUM_GPUs}" + +# Source the environment +source ${VENV_PATH}bin/activate +echo "Active Environment:" +which python + +for ((i=0; i<${#RUN_NAMES[@]}; i++)); +do + RUN_NAME="${RUN_NAMES[i]}" + SEED="${SEEDS[i]}" + # create the run directory + RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" + echo "Starting Run and logging artifcats at 
${RUN_DIR}" + if [ -d "${RUN_DIR}" ] + then + # Directory already exists, we check if the done.out file exists + if [ -f "${RUN_DIR}done.out" ] + then + # Done file already exists so we skip this run + echo "Run already completed. Skipping Run." + continue + else + # Done file doesn't exists (assume pre-emption happened) + # Delete the partially finished contents and start over + echo "Run did not finished correctly. Re-running." + rm -r "${RUN_DIR}" + mkdir "${RUN_DIR}" + fi + else + # Directory doesn't exist yet, so we create it. + echo "Run directory does not exist. Creating it." + mkdir "${RUN_DIR}" + fi + + SERVER_OUTPUT_FILE="${RUN_DIR}server.out" + + # Start the server, divert the outputs to a server file + + echo "Server logging at: ${SERVER_OUTPUT_FILE}" + echo "Launching Server" + + nohup python -m research.cifar10.fed_dgga_pfl.ditto.server \ + --config_path ${SERVER_CONFIG_PATH} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --lam ${LAM_VALUE} \ + --step_size ${STEP_SIZE} \ + > ${SERVER_OUTPUT_FILE} 2>&1 & + + # Sleep for 20 seconds to allow the server to come up. 
+ sleep 20 + + # Start n number of clients and divert the outputs to their own files + n_clients=7 + for (( c=0; c<${n_clients}; c++ )) + do + CLIENT_NAME="client_${c}" + echo "Launching ${CLIENT_NAME}" + + CLIENT_LOG_PATH="${RUN_DIR}${CLIENT_NAME}.out" + echo "${CLIENT_NAME} logging at: ${CLIENT_LOG_PATH}" + nohup python -m research.cifar10.fed_dgga_pfl.ditto.client \ + --artifact_dir ${ARTIFACT_DIR} \ + --dataset_dir ${DATASET_DIR} \ + --run_name ${RUN_NAME} \ + --client_number ${c} \ + --learning_rate ${CLIENT_LR} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --beta ${CLIENT_BETA} \ + > ${CLIENT_LOG_PATH} 2>&1 & + done + + echo "FL Processes Running" + + wait + + # Create a file that verifies that the Run concluded properly + touch "${RUN_DIR}done.out" + echo "Finished FL Processes" + +done diff --git a/research/cifar10/fed_dgga_pfl/ditto/run_hp_sweep.sh b/research/cifar10/fed_dgga_pfl/ditto/run_hp_sweep.sh new file mode 100755 index 000000000..93f251376 --- /dev/null +++ b/research/cifar10/fed_dgga_pfl/ditto/run_hp_sweep.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +############################################### +# Usage: +# +# ./research/cifar10/fed_dgga_pfl/ditto/run_hp_sweep.sh \ +# path_to_config.yaml \ +# path_to_folder_for_artifacts/ \ +# path_to_folder_for_dataset/ \ +# path_to_desired_venv/ +# +# Example: +# ./research/cifar10/fed_dgga_pfl/ditto/run_hp_sweep.sh \ +# research/cifar10/fed_dgga_pfl/ditto/config.yaml \ +# research/cifar10/fed_dgga_pfl/ditto \ +# /datasets/cifar10 \ +# /h/demerson/vector_repositories/fl4health_env/ +# +# Notes: +# 1) The bash command above should be run from the top level directory of the repository. 
+############################################### + +SERVER_CONFIG_PATH=$1 +ARTIFACT_DIR=$2 +DATASET_DIR=$3 +VENV_PATH=$4 + +LR_VALUES=( 0.0001 0.001 0.01 0.1 ) +# Note: These values must correspond to values for the preprocessed CIFAR datasets +BETA_VALUES=( 0.1 0.5 5.0 ) +LAM_VALUES=( 0.001 1.0 ) +STEP_SIZES=( 0.1 0.2 0.5 ) + +SERVER_PORT=8100 + +# Create sweep folder +SWEEP_DIRECTORY="${ARTIFACT_DIR}hp_sweep_results" +echo "Creating sweep folder at ${SWEEP_DIRECTORY}" +mkdir ${SWEEP_DIRECTORY} + +for BETA_VALUE in "${BETA_VALUES[@]}"; do + echo "Creating folder for beta ${BETA_VALUE}" + mkdir "${SWEEP_DIRECTORY}/beta_${BETA_VALUE}" + for LR_VALUE in "${LR_VALUES[@]}"; + do + for LAM_VALUE in "${LAM_VALUES[@]}"; + do + for STEP_SIZE in "${STEP_SIZES[@]}"; + do + EXPERIMENT_NAME="lr_${LR_VALUE}_beta_${BETA_VALUE}_lam_${LAM_VALUE}_step_${STEP_SIZE}" + echo "Beginning Experiment ${EXPERIMENT_NAME}" + EXPERIMENT_DIRECTORY="${SWEEP_DIRECTORY}/beta_${BETA_VALUE}/${EXPERIMENT_NAME}/" + echo "Creating experiment folder ${EXPERIMENT_DIRECTORY}" + mkdir "${EXPERIMENT_DIRECTORY}" + SERVER_ADDRESS="0.0.0.0:${SERVER_PORT}" + echo "Server Address: ${SERVER_ADDRESS}" + SBATCH_COMMAND="research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm \ + ${SERVER_CONFIG_PATH} \ + ${EXPERIMENT_DIRECTORY} \ + ${DATASET_DIR} \ + ${VENV_PATH} \ + ${LR_VALUE} \ + ${LAM_VALUE} \ + ${SERVER_ADDRESS} \ + ${BETA_VALUE} \ + ${STEP_SIZE}" + sbatch ${SBATCH_COMMAND} + ((SERVER_PORT=SERVER_PORT+1)) + done + done + done +done +echo Experiments Launched diff --git a/research/cifar10/fed_dgga_pfl/ditto/server.py b/research/cifar10/fed_dgga_pfl/ditto/server.py new file mode 100644 index 000000000..172336f71 --- /dev/null +++ b/research/cifar10/fed_dgga_pfl/ditto/server.py @@ -0,0 +1,141 @@ +import argparse +from functools import partial +from logging import INFO +from typing import Any, Dict + +import flwr as fl +from flwr.common.logger import log +from flwr.common.typing import Config + +from 
fl4health.client_managers.fixed_sampling_client_manager import FixedSamplingClientManager +from fl4health.strategies.feddg_ga import FairnessMetric, FairnessMetricType +from fl4health.strategies.feddg_ga_with_adaptive_constraint import FedDgGaAdaptiveConstraint +from fl4health.utils.config import load_config +from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn +from fl4health.utils.parameter_extraction import get_all_model_parameters +from fl4health.utils.random import set_all_random_seeds +from research.cifar10.model import ConvNet +from research.cifar10.personal_server import PersonalServer + + +def fit_config( + batch_size: int, + local_epochs: int, + n_server_rounds: int, + n_clients: int, + current_server_round: int, + evaluate_after_fit: bool = False, + pack_losses_with_val_metrics: bool = False, +) -> Config: + return { + "batch_size": batch_size, + "local_epochs": local_epochs, + "n_server_rounds": n_server_rounds, + "n_clients": n_clients, + "current_server_round": current_server_round, + "evaluate_after_fit": evaluate_after_fit, + "pack_losses_with_val_metrics": pack_losses_with_val_metrics, + } + + +def main(config: Dict[str, Any], server_address: str, lam: float, step_size: float) -> None: + # This function will be used to produce a config that is sent to each client to initialize their own environment + fit_config_fn = partial( + fit_config, + config["batch_size"], + config["local_epochs"], + config["n_server_rounds"], + config["n_clients"], + # current_server_round is left unbound here; the strategy supplies it when it calls this function each round + evaluate_after_fit=config.get("evaluate_after_fit", False), + pack_losses_with_val_metrics=config.get("pack_losses_with_val_metrics", False), + ) + + # FixedSamplingClientManager is a requirement here because the sampling cannot + # be different between validation and evaluation for FedDG-GA to work. 
FixedSamplingClientManager + # will return the same sampling until it is told to reset, which in FedDgGaStrategy + # is done right before fit_round. + client_manager = FixedSamplingClientManager() + # Initializing the model on the server side + model = ConvNet(in_channels=3, use_bn=False, dropout=0.1) + + # Define a fairness metric based on the loss associated with the global Ditto model as that is the one being + # aggregated by the server. + ditto_fairness_metric = FairnessMetric(FairnessMetricType.CUSTOM, "val - global_loss", signal=1.0) + + # Server performs simple FedAveraging as its server-side optimization strategy + strategy = FedDgGaAdaptiveConstraint( + min_fit_clients=config["n_clients"], + min_evaluate_clients=config["n_clients"], + # Server waits for min_available_clients before starting FL rounds + min_available_clients=config["n_clients"], + on_fit_config_fn=fit_config_fn, + # We use the same fit config function, as nothing changes for eval + on_evaluate_config_fn=fit_config_fn, + fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, + initial_parameters=get_all_model_parameters(model), + initial_loss_weight=lam, + weight_step_size=step_size, + fairness_metric=ditto_fairness_metric, + ) + + server = PersonalServer(client_manager=client_manager, strategy=strategy) + + fl.server.start_server( + server=server, + server_address=server_address, + config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]), + ) + + log(INFO, "Training Complete") + log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{server.best_aggregated_loss}") + + # Shutdown the server gracefully + server.shutdown() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FL Server Main") + parser.add_argument( + "--config_path", + action="store", + type=str, + help="Path to configuration file.", + default="config.yaml", + ) + parser.add_argument( + "--server_address", + 
action="store", + type=str, + help="Server Address to be used to communicate with the clients", + default="0.0.0.0:8080", + ) + parser.add_argument( + "--seed", + action="store", + type=int, + help="Seed for the random number generators across python, torch, and numpy", + required=False, + ) + parser.add_argument( + "--lam", action="store", type=float, help="Ditto loss weight for local model training", default=0.01 + ) + parser.add_argument( + "--step_size", + action="store", + type=float, + help="Step size for Fed-DGGA Aggregation. Must be between 0.0 and 1.0. Corresponds to d in the original paper", + required=True, + ) + args = parser.parse_args() + + config = load_config(args.config_path) + log(INFO, f"Server Address: {args.server_address}") + log(INFO, f"Lambda: {args.lam}") + log(INFO, f"Step Size: {args.step_size}") + + # Set the random seed for reproducibility + set_all_random_seeds(args.seed) + + main(config, args.server_address, args.lam, args.step_size) diff --git a/research/cifar10/fed_dgga_pfl/fenda/client.py b/research/cifar10/fed_dgga_pfl/fenda/client.py new file mode 100644 index 000000000..4b5ba659f --- /dev/null +++ b/research/cifar10/fed_dgga_pfl/fenda/client.py @@ -0,0 +1,162 @@ +import argparse +import os +from logging import INFO +from pathlib import Path +from typing import Optional, Sequence, Tuple + +import flwr as fl +import torch +from flwr.common.logger import log +from flwr.common.typing import Config +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.clients.fenda_client import FendaClient +from fl4health.model_bases.fenda_base import FendaModel +from fl4health.utils.config import narrow_dict_type +from fl4health.utils.losses import LossMeterType +from fl4health.utils.metrics 
import F1, Accuracy, Metric +from fl4health.utils.random import set_all_random_seeds +from research.cifar10.model import ConvNetFendaModel +from research.cifar10.preprocess import get_preprocessed_data + + +class CifarFendaClient(FendaClient): + def __init__( + self, + data_path: Path, + metrics: Sequence[Metric], + device: torch.device, + client_number: int, + learning_rate: float, + heterogeneity_level: float, + loss_meter_type: LossMeterType = LossMeterType.AVERAGE, + checkpointer: Optional[ClientCheckpointModule] = None, + ) -> None: + super().__init__( + data_path=data_path, + metrics=metrics, + device=device, + loss_meter_type=loss_meter_type, + checkpointer=checkpointer, + ) + self.client_number = client_number + self.heterogeneity_level = heterogeneity_level + self.learning_rate: float = learning_rate + + log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") + + def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + batch_size = narrow_dict_type(config, "batch_size", int) + train_loader, val_loader, _ = get_preprocessed_data( + self.data_path, self.client_number, batch_size, self.heterogeneity_level + ) + return train_loader, val_loader + + def get_criterion(self, config: Config) -> _Loss: + return torch.nn.CrossEntropyLoss() + + def get_optimizer(self, config: Config) -> Optimizer: + return torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) + + def get_model(self, config: Config) -> FendaModel: + return ConvNetFendaModel(in_channels=3, use_bn=False, dropout=0.1).to(self.device) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FL Client Main") + parser.add_argument( + "--artifact_dir", + action="store", + type=str, + help="Path to save client artifacts such as logs and model checkpoints", + required=True, + ) + parser.add_argument( + "--dataset_dir", + action="store", + type=str, + help="Path to the preprocessed Cifar 10 Dataset", + required=True, + ) + 
parser.add_argument( + "--run_name", + action="store", + help="Name of the run, model checkpoints will be saved under a subfolder with this name", + required=True, + ) + parser.add_argument( + "--server_address", + action="store", + type=str, + help="Server Address for the clients to communicate with the server through", + default="0.0.0.0:8080", + ) + parser.add_argument( + "--client_number", + action="store", + type=int, + help="Number of the client for dataset loading", + required=True, + ) + parser.add_argument( + "--learning_rate", action="store", type=float, help="Learning rate for local optimization", default=0.1 + ) + parser.add_argument( + "--seed", + action="store", + type=int, + help="Seed for the random number generators across python, torch, and numpy", + required=False, + ) + parser.add_argument( + "--beta", + action="store", + type=float, + help="Heterogeneity level for the dataset", + required=True, + ) + args = parser.parse_args() + + DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {DEVICE}") + log(INFO, f"Server Address: {args.server_address}") + log(INFO, f"Learning Rate: {args.learning_rate}") + log(INFO, f"Beta: {args.beta}") + + # Set the random seed for reproducibility + set_all_random_seeds(args.seed) + + # Adding extensive checkpointing for the client + checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) + pre_aggregation_best_checkpoint_name = f"pre_aggregation_client_{args.client_number}_best_model.pkl" + pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" + post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" + post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" + checkpointer = ClientCheckpointModule( + pre_aggregation=[ + BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), + 
LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), + ], + post_aggregation=[ + BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), + LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), + ], + ) + + data_path = Path(args.dataset_dir) + client = CifarFendaClient( + data_path=data_path, + metrics=[Accuracy("accuracy"), F1("F1_Score", average="macro")], + device=DEVICE, + client_number=args.client_number, + learning_rate=args.learning_rate, + heterogeneity_level=args.beta, + checkpointer=checkpointer, + ) + + fl.client.start_client(server_address=args.server_address, client=client.to_client()) + client.shutdown() diff --git a/research/cifar10/fed_dgga_pfl/fenda/config.yaml b/research/cifar10/fed_dgga_pfl/fenda/config.yaml new file mode 100644 index 000000000..fb088bf2b --- /dev/null +++ b/research/cifar10/fed_dgga_pfl/fenda/config.yaml @@ -0,0 +1,11 @@ +# Parameters that describe server +n_server_rounds: 50 # The number of rounds to run FL + +# Parameters that describe clients +n_clients: 7 # The number of clients in the FL experiment +local_epochs: 1 # The number of epochs to complete for client +batch_size: 32 # The batch size for client training +# Evaluates model immediately after local training on the validation set (in addition to the training set) +evaluate_after_fit: True +# Packs the measured validation losses with the metrics (required for fed-dgga) +pack_losses_with_val_metrics: True diff --git a/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm b/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm new file mode 100644 index 000000000..e84e7bf2b --- /dev/null +++ b/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm @@ -0,0 +1,168 @@ +#!/bin/bash + +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --mem=32G +#SBATCH --partition=a40 +#SBATCH --qos=m2 +#SBATCH 
--job-name=fl_five_fold_exp +#SBATCH --output=%j_%x.out +#SBATCH --error=%j_%x.err +#SBATCH --time=4:00:00 + +############################################### +# Usage: +# +# sbatch research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm \ +# path_to_config.yaml \ +# path_to_folder_for_artifacts/ \ +# path_to_folder_for_dataset/ \ +# path_to_desired_venv/ \ +# client_side_learning_rate_value \ +# server_address \ +# client_beta \ +# step_size +# +# Example: +# sbatch research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm \ +# research/cifar10/fed_dgga_pfl/fenda/config.yaml \ +# research/cifar10/fed_dgga_pfl/fenda/hp_results/ \ +# /datasets/cifar10 \ +# /h/demerson/vector_repositories/fl4health_env/ \ +# 0.0001 \ +# 0.0.0.0:8080 \ +# 0.1 \ +# 0.2 +# +# Notes: +# 1) The sbatch command above should be run from the top level directory of the repository. +# 2) This example runs fenda. As such the data paths and python launch commands are hardcoded. If you want to change +# the example you run, you need to explicitly modify the code below. +# 3) The parent of the artifact directory needs to ALREADY EXIST. The script only creates the artifact and run folders. +############################################### + +# Note: +# ntasks: Total number of processes to use across world +# ntasks-per-node: How many processes each node should create + +# Set NCCL options +# export NCCL_DEBUG=INFO +# NCCL backend to communicate between GPU workers is not provided in vector's cluster. +# Disable this option in slurm. 
+export NCCL_IB_DISABLE=1 + +if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ + [[ "${SLURM_JOB_PARTITION}" == "rtx6000" ]]; then + echo export NCCL_SOCKET_IFNAME=bond0 on "${SLURM_JOB_PARTITION}" + export NCCL_SOCKET_IFNAME=bond0 +fi + +# Process Inputs + +SERVER_CONFIG_PATH=$1 +ARTIFACT_DIR=$2 +DATASET_DIR=$3 +VENV_PATH=$4 +CLIENT_LR=$5 +SERVER_ADDRESS=$6 +CLIENT_BETA=$7 +STEP_SIZE=$8 + +# Create the artifact directory +mkdir "${ARTIFACT_DIR}" + +RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) +SEEDS=(2021 2022 2023 2024 2025) + +echo "Python Venv Path: ${VENV_PATH}" + +echo "World size: ${SLURM_NTASKS}" +echo "Number of nodes: ${SLURM_NNODES}" +NUM_GPUs=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) +echo "GPUs per node: ${NUM_GPUs}" + +# Source the environment +source ${VENV_PATH}bin/activate +echo "Active Environment:" +which python + +for ((i=0; i<${#RUN_NAMES[@]}; i++)); +do + RUN_NAME="${RUN_NAMES[i]}" + SEED="${SEEDS[i]}" + # create the run directory + RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" + echo "Starting Run and logging artifcats at ${RUN_DIR}" + if [ -d "${RUN_DIR}" ] + then + # Directory already exists, we check if the done.out file exists + if [ -f "${RUN_DIR}done.out" ] + then + # Done file already exists so we skip this run + echo "Run already completed. Skipping Run." + continue + else + # Done file doesn't exists (assume pre-emption happened) + # Delete the partially finished contents and start over + echo "Run did not finished correctly. Re-running." + rm -r "${RUN_DIR}" + mkdir "${RUN_DIR}" + fi + else + # Directory doesn't exist yet, so we create it. + echo "Run directory does not exist. Creating it." 
+ mkdir "${RUN_DIR}" + fi + + SERVER_OUTPUT_FILE="${RUN_DIR}server.out" + + # Start the server, divert the outputs to a server file + + echo "Server logging at: ${SERVER_OUTPUT_FILE}" + echo "Launching Server" + + nohup python -m research.cifar10.fed_dgga_pfl.fenda.server \ + --config_path ${SERVER_CONFIG_PATH} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --step_size ${STEP_SIZE} \ + > ${SERVER_OUTPUT_FILE} 2>&1 & + + # Sleep for 20 seconds to allow the server to come up. + sleep 20 + + # Start n number of clients and divert the outputs to their own files + n_clients=7 + for (( c=0; c<${n_clients}; c++ )) + do + CLIENT_NAME="client_${c}" + echo "Launching ${CLIENT_NAME}" + + CLIENT_LOG_PATH="${RUN_DIR}${CLIENT_NAME}.out" + echo "${CLIENT_NAME} logging at: ${CLIENT_LOG_PATH}" + + nohup python -m research.cifar10.adaptive_pfl.fenda.client \ + --artifact_dir ${ARTIFACT_DIR} \ + --dataset_dir ${DATASET_DIR} \ + --run_name ${RUN_NAME} \ + --client_number ${c} \ + --learning_rate ${CLIENT_LR} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --beta ${CLIENT_BETA} \ + > ${CLIENT_LOG_PATH} 2>&1 & + + done + + echo "FL Processes Running" + + wait + + # Create a file that verifies that the Run concluded properly + touch "${RUN_DIR}done.out" + echo "Finished FL Processes" + +done diff --git a/research/cifar10/fed_dgga_pfl/fenda/run_hp_sweep.sh b/research/cifar10/fed_dgga_pfl/fenda/run_hp_sweep.sh new file mode 100755 index 000000000..5d31bdc55 --- /dev/null +++ b/research/cifar10/fed_dgga_pfl/fenda/run_hp_sweep.sh @@ -0,0 +1,70 @@ +#!/bin/bash + +############################################### +# Usage: +# +# ./research/cifar10/fed_dgga_pfl/fenda/run_hp_sweep.sh \ +# path_to_config.yaml \ +# path_to_folder_for_artifacts/ \ +# path_to_folder_for_dataset/ \ +# path_to_desired_venv/ +# +# Example: +# ./research/cifar10/fed_dgga_pfl/fenda/run_hp_sweep.sh \ +# research/cifar10/fed_dgga_pfl/fenda/config.yaml \ +# research/cifar10/fed_dgga_pfl/fenda \ +# 
/datasets/cifar10 \ +# /h/demerson/vector_repositories/fl4health_env/ +# +# Notes: +# 1) The bash command above should be run from the top level directory of the repository. +############################################### + +SERVER_CONFIG_PATH=$1 +ARTIFACT_DIR=$2 +DATASET_DIR=$3 +VENV_PATH=$4 + +LR_VALUES=( 0.0001 0.001 0.01 0.1 ) +# Note: These values must correspond to values for the preprocessed CIFAR datasets +BETA_VALUES=( 0.1 0.5 5.0 ) +STEP_SIZES=( 0.1 0.2 0.5 ) + +SERVER_PORT=8100 + +# Create sweep folder +SWEEP_DIRECTORY="${ARTIFACT_DIR}hp_sweep_results" +echo "Creating sweep folder at ${SWEEP_DIRECTORY}" +mkdir ${SWEEP_DIRECTORY} + +for BETA_VALUE in "${BETA_VALUES[@]}"; do + echo "Creating folder for beta ${BETA_VALUE}" + mkdir "${SWEEP_DIRECTORY}/beta_${BETA_VALUE}" + for LR_VALUE in "${LR_VALUES[@]}"; + do + for STEP_SIZE in "${STEP_SIZES[@]}"; + do + EXPERIMENT_NAME="lr_${LR_VALUE}_beta_${BETA_VALUE}_step_${STEP_SIZE}" + echo "Beginning Experiment ${EXPERIMENT_NAME}" + EXPERIMENT_DIRECTORY="${SWEEP_DIRECTORY}/beta_${BETA_VALUE}/${EXPERIMENT_NAME}/" + echo "Creating experiment folder ${EXPERIMENT_DIRECTORY}" + mkdir "${EXPERIMENT_DIRECTORY}" + SERVER_ADDRESS="0.0.0.0:${SERVER_PORT}" + echo "Server Address: ${SERVER_ADDRESS}" + SBATCH_COMMAND="research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm \ + ${SERVER_CONFIG_PATH} \ + ${EXPERIMENT_DIRECTORY} \ + ${DATASET_DIR} \ + ${VENV_PATH} \ + ${LR_VALUE} \ + ${SERVER_ADDRESS} \ + ${BETA_VALUE} \ + ${STEP_SIZE}" + sbatch ${SBATCH_COMMAND} + ((SERVER_PORT=SERVER_PORT+1)) + done + done + done + done +done +echo Experiments Launched diff --git a/research/cifar10/fed_dgga_pfl/fenda/server.py b/research/cifar10/fed_dgga_pfl/fenda/server.py new file mode 100644 index 000000000..74a8d5171 --- /dev/null +++ b/research/cifar10/fed_dgga_pfl/fenda/server.py @@ -0,0 +1,133 @@ +import argparse +from functools import partial +from logging import INFO +from typing import Any, Dict + +import flwr as fl +from 
flwr.common.logger import log +from flwr.common.typing import Config + +from fl4health.client_managers.fixed_sampling_client_manager import FixedSamplingClientManager +from fl4health.strategies.feddg_ga import FairnessMetric, FairnessMetricType, FedDgGaStrategy +from fl4health.utils.config import load_config +from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn +from fl4health.utils.parameter_extraction import get_all_model_parameters +from fl4health.utils.random import set_all_random_seeds +from research.cifar10.model import ConvNetFendaModel +from research.cifar10.personal_server import PersonalServer + + +def fit_config( + batch_size: int, + local_epochs: int, + n_server_rounds: int, + n_clients: int, + current_server_round: int, + evaluate_after_fit: bool = False, + pack_losses_with_val_metrics: bool = False, +) -> Config: + return { + "batch_size": batch_size, + "local_epochs": local_epochs, + "n_server_rounds": n_server_rounds, + "n_clients": n_clients, + "current_server_round": current_server_round, + "evaluate_after_fit": evaluate_after_fit, + "pack_losses_with_val_metrics": pack_losses_with_val_metrics, + } + + +def main(config: Dict[str, Any], server_address: str, step_size: float) -> None: + # This function will be used to produce a config that is sent to each client to initialize their own environment + fit_config_fn = partial( + fit_config, + config["batch_size"], + config["local_epochs"], + config["n_server_rounds"], + config["n_clients"], + evaluate_after_fit=config.get("evaluate_after_fit", False), + pack_losses_with_val_metrics=config.get("pack_losses_with_val_metrics", False), + ) + + # FixedSamplingClientManager is a requirement here because the sampling cannot + # be different between validation and evaluation for FedDG-GA to work. FixedSamplingClientManager + # will return the same sampling until it is told to reset, which in FedDgGaStrategy + # is done right before fit_round. 
+ client_manager = FixedSamplingClientManager() + # Initializing the model on the server side + model = ConvNetFendaModel(in_channels=3, use_bn=False, dropout=0.1) + + # Define a fairness metric based on the loss associated with the whole FENDA model + fenda_fairness_metric = FairnessMetric(FairnessMetricType.LOSS) + + # Server performs simple FedAveraging as its server-side optimization strategy + strategy = FedDgGaStrategy( + min_fit_clients=config["n_clients"], + min_evaluate_clients=config["n_clients"], + # Server waits for min_available_clients before starting FL rounds + min_available_clients=config["n_clients"], + on_fit_config_fn=fit_config_fn, + # We use the same fit config function, as nothing changes for eval + on_evaluate_config_fn=fit_config_fn, + fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, + initial_parameters=get_all_model_parameters(model), + weight_step_size=step_size, + fairness_metric=fenda_fairness_metric, + ) + + server = PersonalServer(client_manager=client_manager, strategy=strategy) + + fl.server.start_server( + server=server, + server_address=server_address, + config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]), + ) + + log(INFO, "Training Complete") + log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{server.best_aggregated_loss}") + + # Shutdown the server gracefully + server.shutdown() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FL Server Main") + parser.add_argument( + "--config_path", + action="store", + type=str, + help="Path to configuration file.", + default="config.yaml", + ) + parser.add_argument( + "--server_address", + action="store", + type=str, + help="Server Address to be used to communicate with the clients", + default="0.0.0.0:8080", + ) + parser.add_argument( + "--seed", + action="store", + type=int, + help="Seed for the random number generators across python, torch, and numpy", 
+ required=False, + ) + parser.add_argument( + "--step_size", + action="store", + type=float, + help="Step size for Fed-DGGA Aggregation. Must be between 0.0 and 1.0. Corresponds to d in the original paper", + required=True, + ) + args = parser.parse_args() + + config = load_config(args.config_path) + log(INFO, f"Server Address: {args.server_address}") + log(INFO, f"Step Size: {args.step_size}") + + # Set the random seed for reproducibility + set_all_random_seeds(args.seed) + + main(config, args.server_address, args.step_size) diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py b/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py new file mode 100644 index 000000000..d33e6b0da --- /dev/null +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py @@ -0,0 +1,179 @@ +import argparse +import os +from logging import INFO +from pathlib import Path +from typing import Dict, Optional, Sequence, Tuple + +import flwr as fl +import torch +from flwr.common.logger import log +from flwr.common.typing import Config +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.clients.fenda_ditto_client import FendaDittoClient +from fl4health.model_bases.fenda_base import FendaModel +from fl4health.model_bases.sequential_split_models import SequentiallySplitModel +from fl4health.utils.config import narrow_dict_type +from fl4health.utils.losses import LossMeterType +from fl4health.utils.metrics import F1, Accuracy, Metric +from fl4health.utils.random import set_all_random_seeds +from research.cifar10.model import ConvNetFendaDittoGlobalModel, ConvNetFendaModel +from research.cifar10.preprocess import get_preprocessed_data + + +class CifarFendaDittoClient(FendaDittoClient): + def __init__( + self, + data_path: Path, + 
metrics: Sequence[Metric], + device: torch.device, + client_number: int, + learning_rate: float, + heterogeneity_level: float, + loss_meter_type: LossMeterType = LossMeterType.AVERAGE, + checkpointer: Optional[ClientCheckpointModule] = None, + freeze_global_feature_extractor: bool = False, + ) -> None: + super().__init__( + data_path=data_path, + metrics=metrics, + device=device, + loss_meter_type=loss_meter_type, + checkpointer=checkpointer, + freeze_global_feature_extractor=freeze_global_feature_extractor, + ) + self.client_number = client_number + self.heterogeneity_level = heterogeneity_level + self.learning_rate: float = learning_rate + + log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") + + def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + batch_size = narrow_dict_type(config, "batch_size", int) + train_loader, val_loader, _ = get_preprocessed_data( + self.data_path, self.client_number, batch_size, self.heterogeneity_level + ) + return train_loader, val_loader + + def get_criterion(self, config: Config) -> _Loss: + return torch.nn.CrossEntropyLoss() + + def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=self.learning_rate) + local_optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) + return {"global": global_optimizer, "local": local_optimizer} + + def get_model(self, config: Config) -> FendaModel: + return ConvNetFendaModel(in_channels=3, use_bn=False, dropout=0.1).to(self.device) + + def get_global_model(self, config: Config) -> SequentiallySplitModel: + return ConvNetFendaDittoGlobalModel(in_channels=3, use_bn=False, dropout=0.1).to(self.device) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FL Client Main") + parser.add_argument( + "--artifact_dir", + action="store", + type=str, + help="Path to save client artifacts such as logs and model 
checkpoints", + required=True, + ) + parser.add_argument( + "--dataset_dir", + action="store", + type=str, + help="Path to the preprocessed Cifar 10 Dataset", + required=True, + ) + parser.add_argument( + "--run_name", + action="store", + help="Name of the run, model checkpoints will be saved under a subfolder with this name", + required=True, + ) + parser.add_argument( + "--server_address", + action="store", + type=str, + help="Server Address for the clients to communicate with the server through", + default="0.0.0.0:8080", + ) + parser.add_argument( + "--client_number", + action="store", + type=int, + help="Number of the client for dataset loading", + required=True, + ) + parser.add_argument( + "--learning_rate", action="store", type=float, help="Learning rate for local optimization", default=0.1 + ) + parser.add_argument( + "--seed", + action="store", + type=int, + help="Seed for the random number generators across python, torch, and numpy", + required=False, + ) + parser.add_argument( + "--beta", + action="store", + type=float, + help="Heterogeneity level for the dataset", + required=True, + ) + parser.add_argument( + "--freeze_global_extractor", + action="store_true", + help="Whether or not to freeze the global feature extractor of the FENDA model or not.", + default=False, + ) + args = parser.parse_args() + + DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log(INFO, f"Device to be used: {DEVICE}") + log(INFO, f"Server Address: {args.server_address}") + log(INFO, f"Learning Rate: {args.learning_rate}") + log(INFO, f"Beta: {args.beta}") + if args.freeze_global_extractor: + log(INFO, "Freezing the global feature extractor of the FENDA model") + + # Set the random seed for reproducibility + set_all_random_seeds(args.seed) + + # Adding extensive checkpointing for the client + checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) + pre_aggregation_best_checkpoint_name = f"pre_aggregation_client_{args.client_number}_best_model.pkl" 
+ pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" + post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" + post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" + checkpointer = ClientCheckpointModule( + pre_aggregation=[ + BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), + LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), + ], + post_aggregation=[ + BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), + LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), + ], + ) + + data_path = Path(args.dataset_dir) + client = CifarFendaDittoClient( + data_path=data_path, + metrics=[Accuracy("accuracy"), F1("F1_Score", average="macro")], + device=DEVICE, + client_number=args.client_number, + learning_rate=args.learning_rate, + heterogeneity_level=args.beta, + checkpointer=checkpointer, + freeze_global_feature_extractor=args.freeze_global_extractor, + ) + + fl.client.start_client(server_address=args.server_address, client=client.to_client()) + client.shutdown() diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/config.yaml b/research/cifar10/fed_dgga_pfl/fenda_ditto/config.yaml new file mode 100644 index 000000000..fb088bf2b --- /dev/null +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/config.yaml @@ -0,0 +1,11 @@ +# Parameters that describe server +n_server_rounds: 50 # The number of rounds to run FL + +# Parameters that describe clients +n_clients: 7 # The number of clients in the FL experiment +local_epochs: 1 # The number of epochs to complete for client +batch_size: 32 # The batch size for client training +# Evaluates model immediately after local training on the validation set (in addition to the training set) +evaluate_after_fit: True +# Packs the measured validation losses with the metrics (required for 
fed-dgga) +pack_losses_with_val_metrics: True diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm b/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm new file mode 100644 index 000000000..18122036e --- /dev/null +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm @@ -0,0 +1,189 @@ +#!/bin/bash + +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --mem=32G +#SBATCH --partition=a40 +#SBATCH --qos=m2 +#SBATCH --job-name=fl_five_fold_exp +#SBATCH --output=%j_%x.out +#SBATCH --error=%j_%x.err +#SBATCH --time=4:00:00 + +############################################### +# Usage: +# +# sbatch research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm \ +# path_to_config.yaml \ +# path_to_folder_for_artifacts/ \ +# path_to_folder_for_dataset/ \ +# path_to_desired_venv/ \ +# client_side_learning_rate_value \ +# lambda value \ +# server_address \ +# client_beta \ +# step_size \ +# freeze +# +# Example: +# sbatch research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm \ +# research/cifar10/fed_dgga_pfl/fenda_ditto/config.yaml \ +# research/cifar10/fed_dgga_pfl/fenda_ditto/hp_results/ \ +# /datasets/cifar10 \ +# /h/demerson/vector_repositories/fl4health_env/ \ +# 0.0001 \ +# 0.01 \ +# 0.0.0.0:8080 \ +# 0.1 \ +# 0.2 \ +# "TRUE" +# +# Notes: +# 1) The sbatch command above should be run from the top level directory of the repository. +# 2) This example runs fenda_ditto. As such the data paths and python launch commands are hardcoded. If you want to change +# the example you run, you need to explicitly modify the code below. +# 3) The logging directories need to ALREADY EXIST. The script does not create them. 
+############################################### + +# Note: +# ntasks: Total number of processes to use across world +# ntasks-per-node: How many processes each node should create + +# Set NCCL options +# export NCCL_DEBUG=INFO +# NCCL backend to communicate between GPU workers is not provided in vector's cluster. +# Disable this option in slurm. +export NCCL_IB_DISABLE=1 + +if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ + [[ "${SLURM_JOB_PARTITION}" == "rtx6000" ]]; then + echo export NCCL_SOCKET_IFNAME=bond0 on "${SLURM_JOB_PARTITION}" + export NCCL_SOCKET_IFNAME=bond0 +fi + +# Process Inputs + +SERVER_CONFIG_PATH=$1 +ARTIFACT_DIR=$2 +DATASET_DIR=$3 +VENV_PATH=$4 +CLIENT_LR=$5 +LAM_VALUE=$6 +SERVER_ADDRESS=$7 +CLIENT_BETA=$8 +STEP_SIZE=$9 +FREEZE=${10} + +# Create the artifact directory +mkdir "${ARTIFACT_DIR}" + +RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) +SEEDS=(2021 2022 2023 2024 2025) + +echo "Python Venv Path: ${VENV_PATH}" + +echo "World size: ${SLURM_NTASKS}" +echo "Number of nodes: ${SLURM_NNODES}" +NUM_GPUs=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) +echo "GPUs per node: ${NUM_GPUs}" + +# Source the environment +source ${VENV_PATH}bin/activate +echo "Active Environment:" +which python + +for ((i=0; i<${#RUN_NAMES[@]}; i++)); +do + RUN_NAME="${RUN_NAMES[i]}" + SEED="${SEEDS[i]}" + # create the run directory + RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" + echo "Starting Run and logging artifcats at ${RUN_DIR}" + if [ -d "${RUN_DIR}" ] + then + # Directory already exists, we check if the done.out file exists + if [ -f "${RUN_DIR}done.out" ] + then + # Done file already exists so we skip this run + echo "Run already completed. Skipping Run." + continue + else + # Done file doesn't exists (assume pre-emption happened) + # Delete the partially finished contents and start over + echo "Run did not finished correctly. Re-running." + rm -r "${RUN_DIR}" + mkdir "${RUN_DIR}" + fi + else + # Directory doesn't exist yet, so we create it. 
+ echo "Run directory does not exist. Creating it." + mkdir "${RUN_DIR}" + fi + + SERVER_OUTPUT_FILE="${RUN_DIR}server.out" + + # Start the server, divert the outputs to a server file + + echo "Server logging at: ${SERVER_OUTPUT_FILE}" + echo "Launching Server" + + nohup python -m research.cifar10.fed_dgga_pfl.fenda_ditto.server \ + --config_path ${SERVER_CONFIG_PATH} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --lam ${LAM_VALUE} \ + --step_size ${STEP_SIZE} \ + > ${SERVER_OUTPUT_FILE} 2>&1 & + + # Sleep for 20 seconds to allow the server to come up. + sleep 20 + + # Start n number of clients and divert the outputs to their own files + n_clients=7 + for (( c=0; c<${n_clients}; c++ )) + do + CLIENT_NAME="client_${c}" + echo "Launching ${CLIENT_NAME}" + + CLIENT_LOG_PATH="${RUN_DIR}${CLIENT_NAME}.out" + echo "${CLIENT_NAME} logging at: ${CLIENT_LOG_PATH}" + + if [[ ${FREEZE} == "TRUE" ]]; then + nohup python -m research.cifar10.adaptive_pfl.fenda_ditto.client \ + --artifact_dir ${ARTIFACT_DIR} \ + --dataset_dir ${DATASET_DIR} \ + --run_name ${RUN_NAME} \ + --client_number ${c} \ + --learning_rate ${CLIENT_LR} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --beta ${CLIENT_BETA} \ + --freeze_global_extractor \ + > ${CLIENT_LOG_PATH} 2>&1 & + else + nohup python -m research.cifar10.adaptive_pfl.fenda_ditto.client \ + --artifact_dir ${ARTIFACT_DIR} \ + --dataset_dir ${DATASET_DIR} \ + --run_name ${RUN_NAME} \ + --client_number ${c} \ + --learning_rate ${CLIENT_LR} \ + --server_address ${SERVER_ADDRESS} \ + --seed ${SEED} \ + --beta ${CLIENT_BETA} \ + > ${CLIENT_LOG_PATH} 2>&1 & + fi + + done + + echo "FL Processes Running" + + wait + + # Create a file that verifies that the Run concluded properly + touch "${RUN_DIR}done.out" + echo "Finished FL Processes" + +done diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/run_hp_sweep.sh b/research/cifar10/fed_dgga_pfl/fenda_ditto/run_hp_sweep.sh new file mode 100755 index 000000000..844b087de
--- /dev/null +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/run_hp_sweep.sh @@ -0,0 +1,81 @@ +#!/bin/bash + +############################################### +# Usage: +# +# ./research/cifar10/fed_dgga_pfl/fenda_ditto/run_hp_sweep.sh \ +# path_to_config.yaml \ +# path_to_folder_for_artifacts/ \ +# path_to_folder_for_dataset/ \ +# path_to_desired_venv/ +# +# Example: +# ./research/cifar10/fed_dgga_pfl/fenda_ditto/run_hp_sweep.sh \ +# research/cifar10/fed_dgga_pfl/fenda_ditto/config.yaml \ +# research/cifar10/fed_dgga_pfl/fenda_ditto \ +# /datasets/cifar10 \ +# /h/demerson/vector_repositories/fl4health_env/ +# +# Notes: +# 1) The bash command above should be run from the top level directory of the repository. +############################################### + +SERVER_CONFIG_PATH=$1 +ARTIFACT_DIR=$2 +DATASET_DIR=$3 +VENV_PATH=$4 + +LR_VALUES=( 0.0001 0.001 0.01 0.1 ) +# Note: These values must correspond to values for the preprocessed CIFAR datasets +BETA_VALUES=( 0.1 0.5 5.0 ) +LAM_VALUES=( 0.001 1.0 ) +STEP_SIZES=( 0.1 0.2 0.5 ) +FREEZE_FEATURE_EXTRACTOR=( "TRUE" "FALSE" ) + +SERVER_PORT=8100 + +# Create sweep folder +SWEEP_DIRECTORY="${ARTIFACT_DIR}hp_sweep_results" +echo "Creating sweep folder at ${SWEEP_DIRECTORY}" +mkdir ${SWEEP_DIRECTORY} + +for BETA_VALUE in "${BETA_VALUES[@]}"; do + echo "Creating folder for beta ${BETA_VALUE}" + mkdir "${SWEEP_DIRECTORY}/beta_${BETA_VALUE}" + for LR_VALUE in "${LR_VALUES[@]}"; + do + for LAM_VALUE in "${LAM_VALUES[@]}"; + do + for STEP_SIZE in "${STEP_SIZES[@]}"; + do + for FREEZE in "${FREEZE_FEATURE_EXTRACTOR[@]}"; + do + EXPERIMENT_NAME="lr_${LR_VALUE}_beta_${BETA_VALUE}_lam_${LAM_VALUE}_step_${STEP_SIZE}" + if [[ ${FREEZE} == "TRUE" ]]; then + EXPERIMENT_NAME="${EXPERIMENT_NAME}_freeze" + fi + echo "Beginning Experiment ${EXPERIMENT_NAME}" + EXPERIMENT_DIRECTORY="${SWEEP_DIRECTORY}/beta_${BETA_VALUE}/${EXPERIMENT_NAME}/" + echo "Creating experiment folder ${EXPERIMENT_DIRECTORY}" + mkdir "${EXPERIMENT_DIRECTORY}" + 
SERVER_ADDRESS="0.0.0.0:${SERVER_PORT}" + echo "Server Address: ${SERVER_ADDRESS}" + SBATCH_COMMAND="research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm \ + ${SERVER_CONFIG_PATH} \ + ${EXPERIMENT_DIRECTORY} \ + ${DATASET_DIR} \ + ${VENV_PATH} \ + ${LR_VALUE} \ + ${LAM_VALUE} \ + ${SERVER_ADDRESS} \ + ${BETA_VALUE} \ + ${STEP_SIZE} \ + ${FREEZE}" + sbatch ${SBATCH_COMMAND} + ((SERVER_PORT=SERVER_PORT+1)) + done + done + done + done +done +echo Experiments Launched diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py b/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py new file mode 100644 index 000000000..bd991c0ac --- /dev/null +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py @@ -0,0 +1,140 @@ +import argparse +from functools import partial +from logging import INFO +from typing import Any, Dict + +import flwr as fl +from flwr.common.logger import log +from flwr.common.typing import Config + +from fl4health.client_managers.fixed_sampling_client_manager import FixedSamplingClientManager +from fl4health.strategies.feddg_ga import FairnessMetric, FairnessMetricType +from fl4health.strategies.feddg_ga_with_adaptive_constraint import FedDgGaAdaptiveConstraint +from fl4health.utils.config import load_config +from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn +from fl4health.utils.parameter_extraction import get_all_model_parameters +from fl4health.utils.random import set_all_random_seeds +from research.cifar10.model import ConvNetFendaDittoGlobalModel +from research.cifar10.personal_server import PersonalServer + + +def fit_config( + batch_size: int, + local_epochs: int, + n_server_rounds: int, + n_clients: int, + current_server_round: int, + evaluate_after_fit: bool = False, + pack_losses_with_val_metrics: bool = False, +) -> Config: + return { + "batch_size": batch_size, + "local_epochs": local_epochs, + "n_server_rounds": n_server_rounds, + "n_clients": n_clients, + 
"current_server_round": current_server_round, + "evaluate_after_fit": evaluate_after_fit, + "pack_losses_with_val_metrics": pack_losses_with_val_metrics, + } + + +def main(config: Dict[str, Any], server_address: str, lam: float, step_size: float) -> None: + # This function will be used to produce a config that is sent to each client to initialize their own environment + fit_config_fn = partial( + fit_config, + config["batch_size"], + config["local_epochs"], + config["n_server_rounds"], + config["n_clients"], + evaluate_after_fit=config.get("evaluate_after_fit", False), + pack_losses_with_val_metrics=config.get("pack_losses_with_val_metrics", False), + ) + + # FixedSamplingClientManager is a requirement here because the sampling cannot + # be different between validation and evaluation for FedDG-GA to work. FixedSamplingClientManager + # will return the same sampling until it is told to reset, which in FedDgGaStrategy + # is done right before fit_round. + client_manager = FixedSamplingClientManager() + # Initializing the model on the server side + model = ConvNetFendaDittoGlobalModel(in_channels=3, use_bn=False, dropout=0.1) + + # Define a fairness metric based on the loss associated with the global Ditto model as that is the one being + # aggregated by the server. 
+ fenda_ditto_fairness_metric = FairnessMetric(FairnessMetricType.CUSTOM, "val - global_loss", signal=1.0) + + # Server performs simple FedAveraging as its server-side optimization strategy + strategy = FedDgGaAdaptiveConstraint( + min_fit_clients=config["n_clients"], + min_evaluate_clients=config["n_clients"], + # Server waits for min_available_clients before starting FL rounds + min_available_clients=config["n_clients"], + on_fit_config_fn=fit_config_fn, + # We use the same fit config function, as nothing changes for eval + on_evaluate_config_fn=fit_config_fn, + fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, + initial_parameters=get_all_model_parameters(model), + initial_loss_weight=lam, + weight_step_size=step_size, + fairness_metric=fenda_ditto_fairness_metric, + ) + + server = PersonalServer(client_manager=client_manager, strategy=strategy) + + fl.server.start_server( + server=server, + server_address=server_address, + config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]), + ) + + log(INFO, "Training Complete") + log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{server.best_aggregated_loss}") + + # Shutdown the server gracefully + server.shutdown() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FL Server Main") + parser.add_argument( + "--config_path", + action="store", + type=str, + help="Path to configuration file.", + default="config.yaml", + ) + parser.add_argument( + "--server_address", + action="store", + type=str, + help="Server Address to be used to communicate with the clients", + default="0.0.0.0:8080", + ) + parser.add_argument( + "--seed", + action="store", + type=int, + help="Seed for the random number generators across python, torch, and numpy", + required=False, + ) + parser.add_argument( + "--lam", action="store", type=float, help="FENDA Ditto loss weight for local model training", default=0.01 + ) + 
parser.add_argument( + "--step_size", + action="store", + type=float, + help="Step size for Fed-DGGA Aggregation. Must be between 0.0 and 1.0. Corresponds to d in the original paper", + required=True, + ) + args = parser.parse_args() + + config = load_config(args.config_path) + log(INFO, f"Server Address: {args.server_address}") + log(INFO, f"Lambda: {args.lam}") + log(INFO, f"Step Size: {args.step_size}") + + # Set the random seed for reproducibility + set_all_random_seeds(args.seed) + + main(config, args.server_address, args.lam, args.step_size) From 47825d2369a53dbf9c0c53a31b4c3326049f76b4 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Wed, 9 Oct 2024 09:59:07 -0400 Subject: [PATCH 06/19] Fed-DGGA strategy class name change to shorten it. --- examples/feddg_ga_example/server.py | 4 +- fl4health/strategies/feddg_ga.py | 2 +- .../feddg_ga_with_adaptive_constraint.py | 4 +- research/cifar10/fed_dgga_pfl/fenda/server.py | 4 +- ..._feddg_ga_strategy.py => test_feddg_ga.py} | 40 +++++++++---------- 5 files changed, 27 insertions(+), 27 deletions(-) rename tests/strategies/{test_feddg_ga_strategy.py => test_feddg_ga.py} (92%) diff --git a/examples/feddg_ga_example/server.py b/examples/feddg_ga_example/server.py index b163d24f5..69dce6623 100644 --- a/examples/feddg_ga_example/server.py +++ b/examples/feddg_ga_example/server.py @@ -10,7 +10,7 @@ from fl4health.client_managers.fixed_sampling_client_manager import FixedSamplingClientManager from fl4health.model_bases.apfl_base import ApflModule from fl4health.server.base_server import FlServer -from fl4health.strategies.feddg_ga import FedDgGaStrategy +from fl4health.strategies.feddg_ga import FedDgGa from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -51,7 +51,7 @@ def main(config: Dict[str, 
Any]) -> None: initial_model = ApflModule(MnistNetWithBnAndFrozen()) # Implementation of FedDG-GA as a server side strategy - strategy = FedDgGaStrategy( + strategy = FedDgGa( min_fit_clients=config["n_clients"], min_evaluate_clients=config["n_clients"], # Server waits for min_available_clients before starting FL rounds diff --git a/fl4health/strategies/feddg_ga.py b/fl4health/strategies/feddg_ga.py index e791bbb24..7125c2d79 100644 --- a/fl4health/strategies/feddg_ga.py +++ b/fl4health/strategies/feddg_ga.py @@ -91,7 +91,7 @@ def __init__( self.signal = FairnessMetricType.signal_for_type(metric_type) -class FedDgGaStrategy(FedAvg): +class FedDgGa(FedAvg): def __init__( self, *, diff --git a/fl4health/strategies/feddg_ga_with_adaptive_constraint.py b/fl4health/strategies/feddg_ga_with_adaptive_constraint.py index e99281e50..5640061fd 100644 --- a/fl4health/strategies/feddg_ga_with_adaptive_constraint.py +++ b/fl4health/strategies/feddg_ga_with_adaptive_constraint.py @@ -9,10 +9,10 @@ from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint from fl4health.strategies.aggregate_utils import aggregate_losses -from fl4health.strategies.feddg_ga import FairnessMetric, FairnessMetricType, FedDgGaStrategy +from fl4health.strategies.feddg_ga import FairnessMetric, FairnessMetricType, FedDgGa -class FedDgGaAdaptiveConstraint(FedDgGaStrategy): +class FedDgGaAdaptiveConstraint(FedDgGa): def __init__( self, *, diff --git a/research/cifar10/fed_dgga_pfl/fenda/server.py b/research/cifar10/fed_dgga_pfl/fenda/server.py index 74a8d5171..8306da935 100644 --- a/research/cifar10/fed_dgga_pfl/fenda/server.py +++ b/research/cifar10/fed_dgga_pfl/fenda/server.py @@ -8,7 +8,7 @@ from flwr.common.typing import Config from fl4health.client_managers.fixed_sampling_client_manager import FixedSamplingClientManager -from fl4health.strategies.feddg_ga import FairnessMetric, FairnessMetricType, FedDgGaStrategy +from fl4health.strategies.feddg_ga import 
FairnessMetric, FairnessMetricType, FedDgGa from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -61,7 +61,7 @@ def main(config: Dict[str, Any], server_address: str, step_size: float) -> None: fenda_fairness_metric = FairnessMetric(FairnessMetricType.LOSS) # Server performs simple FedAveraging as its server-side optimization strategy - strategy = FedDgGaStrategy( + strategy = FedDgGa( min_fit_clients=config["n_clients"], min_evaluate_clients=config["n_clients"], # Server waits for min_available_clients before starting FL rounds diff --git a/tests/strategies/test_feddg_ga_strategy.py b/tests/strategies/test_feddg_ga.py similarity index 92% rename from tests/strategies/test_feddg_ga_strategy.py rename to tests/strategies/test_feddg_ga.py index 8e3a7d598..6e619375a 100644 --- a/tests/strategies/test_feddg_ga_strategy.py +++ b/tests/strategies/test_feddg_ga.py @@ -9,7 +9,7 @@ from pytest import approx, raises from fl4health.client_managers.fixed_sampling_client_manager import FixedSamplingClientManager -from fl4health.strategies.feddg_ga import FairnessMetricType, FedDgGaStrategy +from fl4health.strategies.feddg_ga import FairnessMetricType, FedDgGa from tests.test_utils.custom_client_proxy import CustomClientProxy @@ -30,7 +30,7 @@ def on_evaluate_config_fn(server_round: int) -> Dict[str, Scalar]: "pack_losses_with_val_metrics": True, } - strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn, on_evaluate_config_fn=on_evaluate_config_fn) + strategy = FedDgGa(on_fit_config_fn=on_fit_config_fn, on_evaluate_config_fn=on_evaluate_config_fn) assert strategy.num_rounds is None try: @@ -48,7 +48,7 @@ def test_configure_fit_fail() -> None: simple_client_manager = _apply_mocks_to_client_manager(SimpleClientManager()) # Fails with no configure fit - strategy = FedDgGaStrategy() + strategy = 
FedDgGa() with raises(AssertionError): strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) @@ -60,7 +60,7 @@ def on_fit_config_fn(server_round: int) -> Dict[str, Scalar]: "pack_losses_with_val_metrics": True, } - strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn) + strategy = FedDgGa(on_fit_config_fn=on_fit_config_fn) with raises(AssertionError): strategy.configure_fit(1, Parameters([], ""), simple_client_manager) @@ -72,7 +72,7 @@ def on_fit_config_fn_1(server_round: int) -> Dict[str, Scalar]: "pack_losses_with_val_metrics": True, } - strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn_1) + strategy = FedDgGa(on_fit_config_fn=on_fit_config_fn_1) assert strategy.num_rounds is None with raises(AssertionError): @@ -86,7 +86,7 @@ def on_fit_config_fn_2(server_round: int) -> Dict[str, Scalar]: "pack_losses_with_val_metrics": True, } - strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn_2) + strategy = FedDgGa(on_fit_config_fn=on_fit_config_fn_2) assert strategy.num_rounds is None with raises(AssertionError): @@ -99,7 +99,7 @@ def on_fit_config_fn_3(server_round: int) -> Dict[str, Scalar]: "pack_losses_with_val_metrics": True, } - strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn_3) + strategy = FedDgGa(on_fit_config_fn=on_fit_config_fn_3) with raises(AssertionError): strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) @@ -111,7 +111,7 @@ def on_fit_config_fn_4(server_round: int) -> Dict[str, Scalar]: "pack_losses_with_val_metrics": True, } - strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn_4) + strategy = FedDgGa(on_fit_config_fn=on_fit_config_fn_4) with raises(AssertionError): strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) @@ -122,7 +122,7 @@ def on_fit_config_fn_5(server_round: int) -> Dict[str, Scalar]: "evaluate_after_fit": True, } - strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn_5) + strategy = 
FedDgGa(on_fit_config_fn=on_fit_config_fn_5) with raises(AssertionError): strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) @@ -134,7 +134,7 @@ def on_fit_config_fn_6(server_round: int) -> Dict[str, Scalar]: "pack_losses_with_val_metrics": False, } - strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn_6) + strategy = FedDgGa(on_fit_config_fn=on_fit_config_fn_6) with raises(AssertionError): strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) @@ -144,7 +144,7 @@ def test_configure_evaluate_fail() -> None: simple_client_manager = _apply_mocks_to_client_manager(SimpleClientManager()) # Fails with no evaluate fit - strategy = FedDgGaStrategy() + strategy = FedDgGa() with raises(AssertionError): strategy.configure_evaluate(1, Parameters([], ""), fixed_sampling_client_manager) @@ -155,7 +155,7 @@ def on_evaluate_config_fn(server_round: int) -> Dict[str, Scalar]: "pack_losses_with_val_metrics": True, } - strategy = FedDgGaStrategy(on_evaluate_config_fn=on_evaluate_config_fn) + strategy = FedDgGa(on_evaluate_config_fn=on_evaluate_config_fn) with raises(AssertionError): strategy.configure_evaluate(1, Parameters([], ""), simple_client_manager) @@ -165,7 +165,7 @@ def on_evaluate_config_fn_1(server_round: int) -> Dict[str, Scalar]: "foo": 123, } - strategy = FedDgGaStrategy(on_evaluate_config_fn=on_evaluate_config_fn_1) + strategy = FedDgGa(on_evaluate_config_fn=on_evaluate_config_fn_1) with raises(AssertionError): strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) @@ -176,7 +176,7 @@ def on_fit_config_fn_2(server_round: int) -> Dict[str, Scalar]: "pack_losses_with_val_metrics": False, } - strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn_2) + strategy = FedDgGa(on_fit_config_fn=on_fit_config_fn_2) with raises(AssertionError): strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) @@ -191,7 +191,7 @@ def test_aggregate_fit_and_aggregate_evaluate() -> None: 
test_eval_metrics_2 = test_eval_results[1][1].metrics test_initial_adjustment_weight = 1.0 / 3.0 - strategy = FedDgGaStrategy() + strategy = FedDgGa() strategy.num_rounds = 3 strategy.initial_adjustment_weight = test_initial_adjustment_weight @@ -230,7 +230,7 @@ def test_weight_and_aggregate_results_with_default_weights() -> None: test_cid_2 = test_fit_results[1][0].cid test_initial_adjustment_weight = 1.0 / 3.0 - strategy = FedDgGaStrategy() + strategy = FedDgGa() strategy.initial_adjustment_weight = test_initial_adjustment_weight aggregated_results = strategy.weight_and_aggregate_results(test_fit_results) @@ -247,7 +247,7 @@ def test_weight_and_aggregate_results_with_existing_weights() -> None: test_cid_2 = test_fit_results[1][0].cid test_adjustment_weights = {test_cid_1: 0.21, test_cid_2: 0.76} - strategy = FedDgGaStrategy() + strategy = FedDgGa() strategy.adjustment_weights = deepcopy(test_adjustment_weights) aggregated_results = strategy.weight_and_aggregate_results(test_fit_results) @@ -260,7 +260,7 @@ def test_update_weights_by_ga() -> None: test_val_loss_key = FairnessMetricType.LOSS.value test_initial_adjustment_weight = 1.0 / 3.0 - strategy = FedDgGaStrategy() + strategy = FedDgGa() strategy.num_rounds = 3 strategy.initial_adjustment_weight = test_initial_adjustment_weight strategy.train_metrics = { @@ -289,7 +289,7 @@ def test_update_weights_by_ga_with_same_metrics() -> None: test_val_loss_key = FairnessMetricType.LOSS.value test_initial_adjustment_weight = 1.0 / 3.0 - strategy = FedDgGaStrategy() + strategy = FedDgGa() strategy.num_rounds = 3 strategy.initial_adjustment_weight = test_initial_adjustment_weight strategy.train_metrics = { @@ -311,7 +311,7 @@ def test_update_weights_by_ga_with_same_metrics() -> None: def test_get_current_weight_step_size() -> None: - strategy = FedDgGaStrategy() + strategy = FedDgGa() with raises(AssertionError): strategy.get_current_weight_step_size(2) From 749d8b390a19868b67bc6ea6da495c8c5151e807 Mon Sep 17 00:00:00 2001 
From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Wed, 9 Oct 2024 12:59:36 -0400 Subject: [PATCH 07/19] Some small tweaks to fix a bug in unpacking the test metrics --- fl4health/clients/basic_client.py | 5 +-- fl4health/utils/metrics.py | 2 +- .../load_from_checkpoint_example/README.md | 34 ------------------- 3 files changed, 4 insertions(+), 37 deletions(-) delete mode 100644 tests/smoke_tests/load_from_checkpoint_example/README.md diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index e69170d00..7e937e37f 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -24,7 +24,7 @@ from fl4health.reporting.metrics import MetricsReporter from fl4health.utils.config import narrow_dict_type, narrow_dict_type_and_set_attribute from fl4health.utils.losses import EvaluationLosses, LossMeter, LossMeterType, TrainingLosses -from fl4health.utils.metrics import TEST_NUM_EXAMPLES_KEY, Metric, MetricManager, MetricPrefix +from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, Metric, MetricManager, MetricPrefix from fl4health.utils.random import generate_hash from fl4health.utils.typing import LogLevel, TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType @@ -851,7 +851,7 @@ def validate(self, include_losses_in_metrics: bool = False) -> Tuple[float, Dict include_losses_in_metrics=include_losses_in_metrics, ) if self.test_loader: - _, test_metrics = self._validate_or_test( + test_loss, test_metrics = self._validate_or_test( self.test_loader, self.test_loss_meter, self.test_metric_manager, @@ -861,6 +861,7 @@ def validate(self, include_losses_in_metrics: bool = False) -> Tuple[float, Dict # There will be no clashes due to the naming convention associated with the metric managers if self.num_test_samples is not None: val_metrics[TEST_NUM_EXAMPLES_KEY] = self.num_test_samples + val_metrics[TEST_LOSS_KEY] = test_loss val_metrics.update(test_metrics) return val_loss, 
val_metrics diff --git a/fl4health/utils/metrics.py b/fl4health/utils/metrics.py index 77cab3567..74af911b6 100644 --- a/fl4health/utils/metrics.py +++ b/fl4health/utils/metrics.py @@ -18,7 +18,7 @@ class MetricPrefix(Enum): TEST_NUM_EXAMPLES_KEY = f"{MetricPrefix.TEST_PREFIX.value} num_examples" -TEST_LOSS_KEY = f"{MetricPrefix.TEST_PREFIX.value} loss" +TEST_LOSS_KEY = f"{MetricPrefix.TEST_PREFIX.value} checkpoint" class Metric(ABC): diff --git a/tests/smoke_tests/load_from_checkpoint_example/README.md b/tests/smoke_tests/load_from_checkpoint_example/README.md deleted file mode 100644 index bb2705159..000000000 --- a/tests/smoke_tests/load_from_checkpoint_example/README.md +++ /dev/null @@ -1,34 +0,0 @@ -# Basic Federated Learning Example -This example provides an very simple implementation of a federated learning training setup on the CIFAR dataset. The -FL server expects two clients to be spun up (i.e. it will wait until two clients report in before starting training). -Each client has the same "local" dataset. I.e. they each load the complete CIFAR dataset and therefore have the same -training and validation sets. The server has some custom metrics aggregation, but is otherwise a vanilla FL -implementation using FedAvg as the server side optimization. - -## Running the Example -In order to run the example, first ensure you have [installed the dependencies in your virtual environment according to the main README](/README.md#development-requirements) and it has been activated. - -## Starting Server - -The next step is to start the server by running -``` -python -m examples.basic_example.server --config_path /path/to/config.yaml -``` -from the FL4Health directory. 
The following arguments must be present in the specified config file: -* `n_clients`: number of clients the server waits for in order to run the FL training -* `local_epochs`: number of epochs each client will train for locally -* `batch_size`: size of the batches each client will train on -* `n_server_rounds`: The number of rounds to run FL - -## Starting Clients - -Once the server has started and logged "FL starting," the next step, in separate terminals, is to start the two -clients. This is done by simply running (remembering to activate your environment) -``` -python -m examples.basic_example.client --dataset_path /path/to/data -``` -**NOTE**: The argument `dataset_path` has two functions, depending on whether the dataset exists locally or not. If -the dataset already exists at the path specified, it will be loaded from there. Otherwise, the dataset will be -automatically downloaded to the path specified and used in the run. - -After both clients have been started federated learning should commence. 
From dd37afb39300e58d1a0676ef5776fcee40f8cdd4 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Wed, 9 Oct 2024 14:48:37 -0400 Subject: [PATCH 08/19] Fixing a small bug with the smoke test --- tests/smoke_tests/basic_client_metrics.json | 4 ++-- tests/smoke_tests/basic_server_metrics.json | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/smoke_tests/basic_client_metrics.json b/tests/smoke_tests/basic_client_metrics.json index bb85cf85f..c5ace0083 100644 --- a/tests/smoke_tests/basic_client_metrics.json +++ b/tests/smoke_tests/basic_client_metrics.json @@ -18,7 +18,7 @@ "evaluate_metrics": { "val - prediction - accuracy": 0.0942, "test - num_examples": 10000, - "test - loss": 2.30616, + "test - checkpoint": 2.30616, "test - prediction - accuracy": 0.0966 }, "loss": 2.3042 @@ -33,7 +33,7 @@ "evaluate_metrics": { "val - prediction - accuracy": 0.0936, "test - num_examples": 10000, - "test - loss": 2.30109, + "test - checkpoint": 2.30109, "test - prediction - accuracy": 0.0972 }, "loss": 2.2999 diff --git a/tests/smoke_tests/basic_server_metrics.json b/tests/smoke_tests/basic_server_metrics.json index 1d4c55d4e..bcf8e567c 100644 --- a/tests/smoke_tests/basic_server_metrics.json +++ b/tests/smoke_tests/basic_server_metrics.json @@ -4,7 +4,7 @@ "metrics_aggregated": { "val - prediction - accuracy": 0.1031, "test - prediction - accuracy": 0.1039, - "test - loss - aggregated": 2.3613 + "test - checkpoint - aggregated": 2.3613 }, "loss_aggregated": 2.3567 }, @@ -12,7 +12,7 @@ "metrics_aggregated": { "val - prediction - accuracy": 0.0942, "test - prediction - accuracy": 0.0966, - "test - loss - aggregated": 2.3061 + "test - checkpoint - aggregated": 2.3061 }, "loss_aggregated": 2.3042 }, @@ -20,7 +20,7 @@ "metrics_aggregated": { "val - prediction - accuracy": 0.0936, "test - prediction - accuracy": 0.0972, - "test - loss - aggregated": 2.3010 + "test - checkpoint - aggregated": 2.3010 }, 
"loss_aggregated": 2.2999 } From ef119d44c468b652bae9ebb4ce3ff407398ab675 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Wed, 9 Oct 2024 15:47:59 -0400 Subject: [PATCH 09/19] Fixing small smoke test bug. --- tests/smoke_tests/basic_server_metrics.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/smoke_tests/basic_server_metrics.json b/tests/smoke_tests/basic_server_metrics.json index bcf8e567c..1d4c55d4e 100644 --- a/tests/smoke_tests/basic_server_metrics.json +++ b/tests/smoke_tests/basic_server_metrics.json @@ -4,7 +4,7 @@ "metrics_aggregated": { "val - prediction - accuracy": 0.1031, "test - prediction - accuracy": 0.1039, - "test - checkpoint - aggregated": 2.3613 + "test - loss - aggregated": 2.3613 }, "loss_aggregated": 2.3567 }, @@ -12,7 +12,7 @@ "metrics_aggregated": { "val - prediction - accuracy": 0.0942, "test - prediction - accuracy": 0.0966, - "test - checkpoint - aggregated": 2.3061 + "test - loss - aggregated": 2.3061 }, "loss_aggregated": 2.3042 }, @@ -20,7 +20,7 @@ "metrics_aggregated": { "val - prediction - accuracy": 0.0936, "test - prediction - accuracy": 0.0972, - "test - checkpoint - aggregated": 2.3010 + "test - loss - aggregated": 2.3010 }, "loss_aggregated": 2.2999 } From 2554e5ace96fec04f00e245b3f00dbae1506a4b7 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Wed, 9 Oct 2024 16:10:44 -0400 Subject: [PATCH 10/19] Fix up the Fed DG-GA configuration to include required parameter --- tests/smoke_tests/feddg_ga_config.yaml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/smoke_tests/feddg_ga_config.yaml b/tests/smoke_tests/feddg_ga_config.yaml index 50c7aa8c0..71fb1e2cb 100644 --- a/tests/smoke_tests/feddg_ga_config.yaml +++ b/tests/smoke_tests/feddg_ga_config.yaml @@ -5,4 +5,7 @@ n_server_rounds: 3 # The number of rounds to run FL n_clients: 2 # The number of clients in the FL experiment 
local_steps: 5 # The number of local steps (one per batch) to complete for client batch_size: 128 # The batch size for client training -evaluate_after_fit: True # Evaluates model immediately after local training on the validation set (in addition to the training set) +# Evaluates model immediately after local training on the validation set (in addition to the training set) +evaluate_after_fit: True +# Packs the measured validation losses with the metrics (required for fed-dgga) +pack_losses_with_val_metrics: True From 4bde4a45e292e7500051ca63f4431201fd1a27c6 Mon Sep 17 00:00:00 2001 From: David Emerson Date: Mon, 28 Oct 2024 13:20:02 -0400 Subject: [PATCH 11/19] Dataset preprocessing script, setting max retries to be 'infinite.' Minor fix to the spacing of seeds in the slrm scripts to be safe. Fixing a bug in the preprocess.slrm, setting placeholders to be defaults in the preprocess_all.sh --- .../cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm | 2 +- .../cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm | 2 +- .../adaptive_pfl/fenda_ditto/run_fold_experiment.slrm | 2 +- .../cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm | 2 +- research/cifar10/ditto/run_fold_experiment.slrm | 2 +- research/cifar10/ditto_deep_mmd/run_fold_experiment.slrm | 2 +- research/cifar10/ditto_mkmmd/run_fold_experiment.slrm | 2 +- .../cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm | 2 +- .../cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm | 2 +- .../fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm | 2 +- research/cifar10/fedavg/run_fold_experiment.slrm | 2 +- research/cifar10/pfl_preprocess_scripts/preprocess.slrm | 6 +----- research/cifar10/pfl_preprocess_scripts/preprocess_all.sh | 8 ++++---- research/cifar10/preprocess.py | 8 ++++---- .../fed_heart_disease/ditto/run_fold_experiment.slrm | 2 +- .../fed_heart_disease/fedper/run_fold_experiment.slrm | 2 +- .../fed_heart_disease/fenda/run_fold_experiment.slrm | 2 +- .../fed_heart_disease/moon/run_fold_experiment.slrm | 
2 +- .../fed_heart_disease/perfcl/run_fold_experiment.slrm | 2 +- .../flamby/fed_isic2019/ditto/run_fold_experiment.slrm | 2 +- .../fed_isic2019/ditto_deep_mmd/run_fold_experiment.slrm | 2 +- .../fed_isic2019/ditto_mkmmd/run_fold_experiment.slrm | 2 +- .../flamby/fed_isic2019/fedper/run_fold_experiment.slrm | 2 +- .../flamby/fed_isic2019/fenda/run_fold_experiment.slrm | 2 +- .../flamby/fed_isic2019/moon/run_fold_experiment.slrm | 2 +- .../fed_isic2019/mr_mtl_mkmmd/run_fold_experiment.slrm | 2 +- .../flamby/fed_isic2019/perfcl/run_fold_experiment.slrm | 2 +- research/flamby/fed_ixi/ditto/run_fold_experiment.slrm | 2 +- research/flamby/fed_ixi/fedper/run_fold_experiment.slrm | 2 +- research/flamby/fed_ixi/fenda/run_fold_experiment.slrm | 2 +- research/flamby/fed_ixi/moon/run_fold_experiment.slrm | 2 +- research/flamby/fed_ixi/perfcl/run_fold_experiment.slrm | 2 +- 32 files changed, 38 insertions(+), 42 deletions(-) mode change 100644 => 100755 research/cifar10/pfl_preprocess_scripts/preprocess_all.sh diff --git a/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm index 05749e8b0..4a1d81448 100644 --- a/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm @@ -78,7 +78,7 @@ ADAPT=$9 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm index 47c3dbaf1..4187e63a5 100644 --- a/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm @@ -78,7 +78,7 @@ ADAPT=$9 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 
2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm index e5d58086f..aec4240e2 100644 --- a/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm @@ -81,7 +81,7 @@ FREEZE=${10} mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm index 77db6de76..7c17f3762 100644 --- a/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm @@ -78,7 +78,7 @@ ADAPT=$9 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/cifar10/ditto/run_fold_experiment.slrm b/research/cifar10/ditto/run_fold_experiment.slrm index 718c745f7..f025f409a 100644 --- a/research/cifar10/ditto/run_fold_experiment.slrm +++ b/research/cifar10/ditto/run_fold_experiment.slrm @@ -75,7 +75,7 @@ CLIENT_BETA=$8 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/cifar10/ditto_deep_mmd/run_fold_experiment.slrm b/research/cifar10/ditto_deep_mmd/run_fold_experiment.slrm index e9ab0c713..815c326f7 100644 --- a/research/cifar10/ditto_deep_mmd/run_fold_experiment.slrm +++ b/research/cifar10/ditto_deep_mmd/run_fold_experiment.slrm @@ -81,7 +81,7 @@ CLIENT_BETA=${10} mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 
2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/cifar10/ditto_mkmmd/run_fold_experiment.slrm b/research/cifar10/ditto_mkmmd/run_fold_experiment.slrm index a324107b7..9bc11c7b6 100644 --- a/research/cifar10/ditto_mkmmd/run_fold_experiment.slrm +++ b/research/cifar10/ditto_mkmmd/run_fold_experiment.slrm @@ -87,7 +87,7 @@ CLIENT_BETA=${12} mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm b/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm index 5330439ec..d31b6031f 100644 --- a/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm +++ b/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm @@ -77,7 +77,7 @@ STEP_SIZE=$9 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm b/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm index e84e7bf2b..610c57022 100644 --- a/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm +++ b/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm @@ -75,7 +75,7 @@ STEP_SIZE=$8 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm b/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm index 18122036e..2c74e088d 100644 --- a/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm @@ -81,7 +81,7 @@ FREEZE=${10} mkdir "${ARTIFACT_DIR}" 
RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/cifar10/fedavg/run_fold_experiment.slrm b/research/cifar10/fedavg/run_fold_experiment.slrm index 551815dc1..bb9fcec5a 100644 --- a/research/cifar10/fedavg/run_fold_experiment.slrm +++ b/research/cifar10/fedavg/run_fold_experiment.slrm @@ -72,7 +72,7 @@ CLIENT_BETA=$7 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/cifar10/pfl_preprocess_scripts/preprocess.slrm b/research/cifar10/pfl_preprocess_scripts/preprocess.slrm index ca04f463d..9a1792293 100644 --- a/research/cifar10/pfl_preprocess_scripts/preprocess.slrm +++ b/research/cifar10/pfl_preprocess_scripts/preprocess.slrm @@ -45,10 +45,6 @@ echo "Dirichlet Beta: ${BETA}" echo "Number of partitions to produce: ${NUM_PARTITIONS}" echo "Logs being placed in: ${LOG_DIR}" -SERVER_ADDRESS="${SLURMD_NODENAME}:${SERVER_PORT}" - -echo "Server Address: ${SERVER_ADDRESS}" - LOG_PATH="${LOG_DIR}preprocess_${BETA}_${NUM_PARTITIONS}_${SEED}.log" echo "World size: ${SLURM_NTASKS}" @@ -61,7 +57,7 @@ source ${VENV_PATH}bin/activate echo "Active Environment" which python -nohup python -m research.cifar10.preprocess.py \ +nohup python -m research.cifar10.preprocess \ --dataset_dir ${DATASET_DIR} \ --save_dataset_dir ${OUTPUT_DIR} \ --seed ${SEED} \ diff --git a/research/cifar10/pfl_preprocess_scripts/preprocess_all.sh b/research/cifar10/pfl_preprocess_scripts/preprocess_all.sh old mode 100644 new mode 100755 index 526e131e5..2c5d5b7d6 --- a/research/cifar10/pfl_preprocess_scripts/preprocess_all.sh +++ b/research/cifar10/pfl_preprocess_scripts/preprocess_all.sh @@ -6,11 +6,11 @@ SEEDS=( 2024 2025 2026 ) BETAS=( 0.1 0.5 5.0 ) NUM_PARTITIONS=( 7 7 7 ) -ORIGINAL_DATA_DIR="PLACEHOLDER" 
+ORIGINAL_DATA_DIR="research/cifar10/datasets/cifar10/" DESTINATION_DIRS=( \ - "DEST1" \ - "DEST2" \ - "DEST3" \ + "research/cifar10/datasets/cifar10/" \ + "research/cifar10/datasets/cifar10/" \ + "research/cifar10/datasets/cifar10/" \ ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/cifar10/preprocess.py b/research/cifar10/preprocess.py index b637af0ff..8e3f9632f 100644 --- a/research/cifar10/preprocess.py +++ b/research/cifar10/preprocess.py @@ -100,10 +100,10 @@ def preprocess_data( # Partition train data heterogeneous_partitioner = DirichletLabelBasedAllocation( - number_of_partitions=num_clients, unique_labels=list(range(10)), beta=beta, min_label_examples=2 + number_of_partitions=num_clients, unique_labels=list(range(10)), beta=beta, min_label_examples=1 ) train_partitioned_datasets, train_partitioned_dist = heterogeneous_partitioner.partition_dataset( - training_set, max_retries=5 + training_set, max_retries=-1 ) # Partition validation and test data @@ -111,9 +111,9 @@ def preprocess_data( number_of_partitions=num_clients, unique_labels=list(range(10)), prior_distribution=train_partitioned_dist ) validation_partitioned_datasets, _ = heterogeneous_partitioner_with_prior.partition_dataset( - validation_set, max_retries=5 + validation_set, max_retries=-1 ) - test_partitioned_datasets, _ = heterogeneous_partitioner_with_prior.partition_dataset(test_set, max_retries=5) + test_partitioned_datasets, _ = heterogeneous_partitioner_with_prior.partition_dataset(test_set, max_retries=-1) return train_partitioned_datasets, validation_partitioned_datasets, test_partitioned_datasets diff --git a/research/flamby/fed_heart_disease/ditto/run_fold_experiment.slrm b/research/flamby/fed_heart_disease/ditto/run_fold_experiment.slrm index 1fea649fc..f796bf611 100644 --- a/research/flamby/fed_heart_disease/ditto/run_fold_experiment.slrm +++ b/research/flamby/fed_heart_disease/ditto/run_fold_experiment.slrm @@ -72,7 +72,7 @@ SERVER_ADDRESS=$7 mkdir "${ARTIFACT_DIR}" 
RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/flamby/fed_heart_disease/fedper/run_fold_experiment.slrm b/research/flamby/fed_heart_disease/fedper/run_fold_experiment.slrm index 257553392..d7c18fcb8 100644 --- a/research/flamby/fed_heart_disease/fedper/run_fold_experiment.slrm +++ b/research/flamby/fed_heart_disease/fedper/run_fold_experiment.slrm @@ -69,7 +69,7 @@ SERVER_ADDRESS=$6 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/flamby/fed_heart_disease/fenda/run_fold_experiment.slrm b/research/flamby/fed_heart_disease/fenda/run_fold_experiment.slrm index 31f5d7b88..6b8cb9e45 100644 --- a/research/flamby/fed_heart_disease/fenda/run_fold_experiment.slrm +++ b/research/flamby/fed_heart_disease/fenda/run_fold_experiment.slrm @@ -69,7 +69,7 @@ SERVER_ADDRESS=$6 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/flamby/fed_heart_disease/moon/run_fold_experiment.slrm b/research/flamby/fed_heart_disease/moon/run_fold_experiment.slrm index 80d9e42a8..02fc66ef4 100644 --- a/research/flamby/fed_heart_disease/moon/run_fold_experiment.slrm +++ b/research/flamby/fed_heart_disease/moon/run_fold_experiment.slrm @@ -72,7 +72,7 @@ CLIENT_MU=$7 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/flamby/fed_heart_disease/perfcl/run_fold_experiment.slrm b/research/flamby/fed_heart_disease/perfcl/run_fold_experiment.slrm index 2fe8ce176..646b51482 100644 --- 
a/research/flamby/fed_heart_disease/perfcl/run_fold_experiment.slrm +++ b/research/flamby/fed_heart_disease/perfcl/run_fold_experiment.slrm @@ -75,7 +75,7 @@ CLIENT_GAMMA=$8 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/flamby/fed_isic2019/ditto/run_fold_experiment.slrm b/research/flamby/fed_isic2019/ditto/run_fold_experiment.slrm index 7cf463cff..278e5e565 100644 --- a/research/flamby/fed_isic2019/ditto/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/ditto/run_fold_experiment.slrm @@ -72,7 +72,7 @@ SERVER_ADDRESS=$7 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/flamby/fed_isic2019/ditto_deep_mmd/run_fold_experiment.slrm b/research/flamby/fed_isic2019/ditto_deep_mmd/run_fold_experiment.slrm index 9ea2fe49c..1a622040f 100644 --- a/research/flamby/fed_isic2019/ditto_deep_mmd/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/ditto_deep_mmd/run_fold_experiment.slrm @@ -78,7 +78,7 @@ SERVER_ADDRESS=$9 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/flamby/fed_isic2019/ditto_mkmmd/run_fold_experiment.slrm b/research/flamby/fed_isic2019/ditto_mkmmd/run_fold_experiment.slrm index 4579cb5e6..648567687 100644 --- a/research/flamby/fed_isic2019/ditto_mkmmd/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/ditto_mkmmd/run_fold_experiment.slrm @@ -81,7 +81,7 @@ SERVER_ADDRESS=${11} mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git 
a/research/flamby/fed_isic2019/fedper/run_fold_experiment.slrm b/research/flamby/fed_isic2019/fedper/run_fold_experiment.slrm index af4230aa1..e68609537 100644 --- a/research/flamby/fed_isic2019/fedper/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/fedper/run_fold_experiment.slrm @@ -69,7 +69,7 @@ SERVER_ADDRESS=$6 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/flamby/fed_isic2019/fenda/run_fold_experiment.slrm b/research/flamby/fed_isic2019/fenda/run_fold_experiment.slrm index 0a8bb2ac0..e8a25f042 100644 --- a/research/flamby/fed_isic2019/fenda/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/fenda/run_fold_experiment.slrm @@ -69,7 +69,7 @@ SERVER_ADDRESS=$6 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/flamby/fed_isic2019/moon/run_fold_experiment.slrm b/research/flamby/fed_isic2019/moon/run_fold_experiment.slrm index 45a50c15a..5514b3083 100644 --- a/research/flamby/fed_isic2019/moon/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/moon/run_fold_experiment.slrm @@ -72,7 +72,7 @@ CLIENT_MU=$7 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/flamby/fed_isic2019/mr_mtl_mkmmd/run_fold_experiment.slrm b/research/flamby/fed_isic2019/mr_mtl_mkmmd/run_fold_experiment.slrm index 892711781..e072e977e 100644 --- a/research/flamby/fed_isic2019/mr_mtl_mkmmd/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/mr_mtl_mkmmd/run_fold_experiment.slrm @@ -81,7 +81,7 @@ SERVER_ADDRESS=${11} mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 
2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/flamby/fed_isic2019/perfcl/run_fold_experiment.slrm b/research/flamby/fed_isic2019/perfcl/run_fold_experiment.slrm index 40a80105a..e27fb8e17 100644 --- a/research/flamby/fed_isic2019/perfcl/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/perfcl/run_fold_experiment.slrm @@ -75,7 +75,7 @@ CLIENT_GAMMA=$8 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/flamby/fed_ixi/ditto/run_fold_experiment.slrm b/research/flamby/fed_ixi/ditto/run_fold_experiment.slrm index 72cc00c19..d8155aca6 100644 --- a/research/flamby/fed_ixi/ditto/run_fold_experiment.slrm +++ b/research/flamby/fed_ixi/ditto/run_fold_experiment.slrm @@ -72,7 +72,7 @@ SERVER_ADDRESS=$7 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/flamby/fed_ixi/fedper/run_fold_experiment.slrm b/research/flamby/fed_ixi/fedper/run_fold_experiment.slrm index d08f1d2a0..13ede0b66 100644 --- a/research/flamby/fed_ixi/fedper/run_fold_experiment.slrm +++ b/research/flamby/fed_ixi/fedper/run_fold_experiment.slrm @@ -69,7 +69,7 @@ SERVER_ADDRESS=$6 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/flamby/fed_ixi/fenda/run_fold_experiment.slrm b/research/flamby/fed_ixi/fenda/run_fold_experiment.slrm index 4ce177c10..1997fbeda 100644 --- a/research/flamby/fed_ixi/fenda/run_fold_experiment.slrm +++ b/research/flamby/fed_ixi/fenda/run_fold_experiment.slrm @@ -69,7 +69,7 @@ SERVER_ADDRESS=$6 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) 
-SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/flamby/fed_ixi/moon/run_fold_experiment.slrm b/research/flamby/fed_ixi/moon/run_fold_experiment.slrm index 54d723929..1466eb86b 100644 --- a/research/flamby/fed_ixi/moon/run_fold_experiment.slrm +++ b/research/flamby/fed_ixi/moon/run_fold_experiment.slrm @@ -72,7 +72,7 @@ CLIENT_MU=$7 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" diff --git a/research/flamby/fed_ixi/perfcl/run_fold_experiment.slrm b/research/flamby/fed_ixi/perfcl/run_fold_experiment.slrm index c74dc6091..df7a6396f 100644 --- a/research/flamby/fed_ixi/perfcl/run_fold_experiment.slrm +++ b/research/flamby/fed_ixi/perfcl/run_fold_experiment.slrm @@ -75,7 +75,7 @@ CLIENT_GAMMA=$8 mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) -SEEDS=(2021 2022 2023 2024 2025) +SEEDS=( 2021 2022 2023 2024 2025 ) echo "Python Venv Path: ${VENV_PATH}" From 3d1cb29fb6858636fc21d467cac5fbb5284f3baa Mon Sep 17 00:00:00 2001 From: David Emerson Date: Mon, 28 Oct 2024 17:08:41 -0400 Subject: [PATCH 12/19] Some bug fixes and reducing the number of rounds --- .../fedprox_server.py | 4 +-- .../tabular_feature_alignment_server.py | 9 +++-- .../cifar10/adaptive_pfl/ditto/config.yaml | 2 +- .../cifar10/adaptive_pfl/fedprox/config.yaml | 2 +- .../cifar10/adaptive_pfl/fedprox/server.py | 1 - .../adaptive_pfl/fenda_ditto/config.yaml | 2 +- .../cifar10/adaptive_pfl/mrmtl/config.yaml | 2 +- research/cifar10/model.py | 34 ++++++++++++++++--- 8 files changed, 39 insertions(+), 17 deletions(-) diff --git a/fl4health/server/adaptive_constraint_servers/fedprox_server.py b/fl4health/server/adaptive_constraint_servers/fedprox_server.py index 4004f81a7..7f0d1a126 100644 --- a/fl4health/server/adaptive_constraint_servers/fedprox_server.py +++ 
b/fl4health/server/adaptive_constraint_servers/fedprox_server.py @@ -42,8 +42,8 @@ def __init__( criteria. If none, then no server-side checkpointing is performed. Multiple checkpointers can also be passed in a sequence to checkpoint based on multiple criteria. Defaults to None. - reporters (Sequence[BaseReporter], optional): A sequence of FL4Health - reporters which the server should send data to before and after each round. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should + send data to before and after each round. """ assert isinstance( strategy, FedAvgWithAdaptiveConstraint diff --git a/fl4health/server/tabular_feature_alignment_server.py b/fl4health/server/tabular_feature_alignment_server.py index cff33517a..f1e867138 100644 --- a/fl4health/server/tabular_feature_alignment_server.py +++ b/fl4health/server/tabular_feature_alignment_server.py @@ -35,15 +35,14 @@ class TabularFeatureAlignmentServer(FlServer): strategy (Optional[Strategy], optional): The aggregation strategy to be used by the server to handle. client updates and other information potentially sent by the participating clients. If None the strategy is FedAvg as set by the flwr Server. - wandb_reporter (Optional[ServerWandBReporter], optional): To be provided if the server is to log - information and results to a Weights and Biases account. If None is provided, no logging occurs. - Defaults to None. checkpointer (Optional[TorchCheckpointer], optional): To be provided if the server should perform server side checkpointing based on some criteria. If none, then no server-side checkpointing is performed. Defaults to None. tab_features_source_of_truth (Optional[TabularFeaturesInfoEncoder]): The information that is required - for aligning client features. If it is not specified, then the server will randomly poll a client - and gather this information from its data source. + for aligning client features. 
If it is not specified, then the server will randomly poll a client + and gather this information from its data source. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should + send data to before and after each round. """ def __init__( diff --git a/research/cifar10/adaptive_pfl/ditto/config.yaml b/research/cifar10/adaptive_pfl/ditto/config.yaml index 323b2a693..9b1135ca8 100644 --- a/research/cifar10/adaptive_pfl/ditto/config.yaml +++ b/research/cifar10/adaptive_pfl/ditto/config.yaml @@ -1,5 +1,5 @@ # Parameters that describe server -n_server_rounds: 50 # The number of rounds to run FL +n_server_rounds: 20 # The number of rounds to run FL # Parameters that describe clients n_clients: 7 # The number of clients in the FL experiment diff --git a/research/cifar10/adaptive_pfl/fedprox/config.yaml b/research/cifar10/adaptive_pfl/fedprox/config.yaml index 323b2a693..9b1135ca8 100644 --- a/research/cifar10/adaptive_pfl/fedprox/config.yaml +++ b/research/cifar10/adaptive_pfl/fedprox/config.yaml @@ -1,5 +1,5 @@ # Parameters that describe server -n_server_rounds: 50 # The number of rounds to run FL +n_server_rounds: 20 # The number of rounds to run FL # Parameters that describe clients n_clients: 7 # The number of clients in the FL experiment diff --git a/research/cifar10/adaptive_pfl/fedprox/server.py b/research/cifar10/adaptive_pfl/fedprox/server.py index 0e75e0bbc..5032342a0 100644 --- a/research/cifar10/adaptive_pfl/fedprox/server.py +++ b/research/cifar10/adaptive_pfl/fedprox/server.py @@ -81,7 +81,6 @@ def main( server = FedProxServer( client_manager=client_manager, model=model, - wandb_reporter=None, strategy=strategy, checkpointer=checkpointer, ) diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/config.yaml b/research/cifar10/adaptive_pfl/fenda_ditto/config.yaml index 323b2a693..9b1135ca8 100644 --- a/research/cifar10/adaptive_pfl/fenda_ditto/config.yaml +++ 
b/research/cifar10/adaptive_pfl/fenda_ditto/config.yaml @@ -1,5 +1,5 @@ # Parameters that describe server -n_server_rounds: 50 # The number of rounds to run FL +n_server_rounds: 20 # The number of rounds to run FL # Parameters that describe clients n_clients: 7 # The number of clients in the FL experiment diff --git a/research/cifar10/adaptive_pfl/mrmtl/config.yaml b/research/cifar10/adaptive_pfl/mrmtl/config.yaml index 323b2a693..9b1135ca8 100644 --- a/research/cifar10/adaptive_pfl/mrmtl/config.yaml +++ b/research/cifar10/adaptive_pfl/mrmtl/config.yaml @@ -1,5 +1,5 @@ # Parameters that describe server -n_server_rounds: 50 # The number of rounds to run FL +n_server_rounds: 20 # The number of rounds to run FL # Parameters that describe clients n_clients: 7 # The number of clients in the FL experiment diff --git a/research/cifar10/model.py b/research/cifar10/model.py index ef17fda5e..dc9e38725 100644 --- a/research/cifar10/model.py +++ b/research/cifar10/model.py @@ -80,7 +80,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class ConvNetClassifier(ParallelSplitHeadModule): +class ConvNetClassifier(Module): + def __init__( + self, + h: int = 32, + w: int = 32, + hidden: int = 2048, + class_num: int = 10, + dropout: float = 0.0, + ) -> None: + super().__init__() + + self.fc1 = Linear((h // 2 // 2) * (w // 2 // 2) * 64, hidden) + self.fc2 = Linear(hidden, class_num) + + self.relu = ReLU(inplace=True) + self.dropout_layer = nn.Dropout(p=dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.dropout_layer(x) + x = self.relu(self.fc1(x)) + x = self.dropout_layer(x) + x = self.fc2(x) + + return x + + +class ConvNetFendaClassifier(ParallelSplitHeadModule): def __init__( self, join_mode: ParallelFeatureJoinMode, @@ -127,7 +153,7 @@ def __init__( # is also set to 0 by default for FedIXI local_module = ConvNetFeatureExtractor(in_channels, use_bn) global_module = ConvNetFeatureExtractor(in_channels, use_bn) - model_head = 
ConvNetClassifier( + model_head = ConvNetFendaClassifier( ParallelFeatureJoinMode.CONCATENATE, h=h, w=w, hidden=hidden, class_num=class_num, dropout=dropout ) super().__init__(local_module=local_module, global_module=global_module, model_head=model_head) @@ -145,7 +171,5 @@ def __init__( dropout: float = 0.0, ) -> None: base_module = ConvNetFeatureExtractor(in_channels, use_bn) - head_module = ConvNetClassifier( - ParallelFeatureJoinMode.CONCATENATE, h=h, w=w, hidden=hidden, class_num=class_num, dropout=dropout - ) + head_module = ConvNetClassifier(h=h, w=w, hidden=hidden, class_num=class_num, dropout=dropout) super().__init__(base_module, head_module, flatten_features=False) From 41b471c2dbb1d60b41024020db18b01d84317088 Mon Sep 17 00:00:00 2001 From: David Emerson Date: Tue, 29 Oct 2024 14:45:55 -0400 Subject: [PATCH 13/19] Reducing the number of server rounds. --- research/cifar10/fed_dgga_pfl/ditto/config.yaml | 2 +- research/cifar10/fed_dgga_pfl/fenda/config.yaml | 2 +- research/cifar10/fed_dgga_pfl/fenda_ditto/config.yaml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/research/cifar10/fed_dgga_pfl/ditto/config.yaml b/research/cifar10/fed_dgga_pfl/ditto/config.yaml index fb088bf2b..07c4be83f 100644 --- a/research/cifar10/fed_dgga_pfl/ditto/config.yaml +++ b/research/cifar10/fed_dgga_pfl/ditto/config.yaml @@ -1,5 +1,5 @@ # Parameters that describe server -n_server_rounds: 50 # The number of rounds to run FL +n_server_rounds: 20 # The number of rounds to run FL # Parameters that describe clients n_clients: 7 # The number of clients in the FL experiment diff --git a/research/cifar10/fed_dgga_pfl/fenda/config.yaml b/research/cifar10/fed_dgga_pfl/fenda/config.yaml index fb088bf2b..07c4be83f 100644 --- a/research/cifar10/fed_dgga_pfl/fenda/config.yaml +++ b/research/cifar10/fed_dgga_pfl/fenda/config.yaml @@ -1,5 +1,5 @@ # Parameters that describe server -n_server_rounds: 50 # The number of rounds to run FL +n_server_rounds: 20 # The number of 
rounds to run FL # Parameters that describe clients n_clients: 7 # The number of clients in the FL experiment diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/config.yaml b/research/cifar10/fed_dgga_pfl/fenda_ditto/config.yaml index fb088bf2b..07c4be83f 100644 --- a/research/cifar10/fed_dgga_pfl/fenda_ditto/config.yaml +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/config.yaml @@ -1,5 +1,5 @@ # Parameters that describe server -n_server_rounds: 50 # The number of rounds to run FL +n_server_rounds: 20 # The number of rounds to run FL # Parameters that describe clients n_clients: 7 # The number of clients in the FL experiment From d504adaf0ae3be5868a28bda2fa46f33dfa981b5 Mon Sep 17 00:00:00 2001 From: David Emerson Date: Tue, 29 Oct 2024 15:54:42 -0400 Subject: [PATCH 14/19] Adding weighted F1 to measured metrics, bug fix for repeated evaluate_after_fit_argument --- research/cifar10/adaptive_pfl/ditto/client.py | 6 +++++- research/cifar10/adaptive_pfl/fedprox/client.py | 6 +++++- research/cifar10/adaptive_pfl/fenda_ditto/client.py | 6 +++++- research/cifar10/adaptive_pfl/mrmtl/client.py | 6 +++++- research/cifar10/fed_dgga_pfl/ditto/client.py | 6 +++++- research/cifar10/fed_dgga_pfl/ditto/server.py | 1 - research/cifar10/fed_dgga_pfl/fenda/client.py | 6 +++++- research/cifar10/fed_dgga_pfl/fenda_ditto/client.py | 6 +++++- 8 files changed, 35 insertions(+), 8 deletions(-) diff --git a/research/cifar10/adaptive_pfl/ditto/client.py b/research/cifar10/adaptive_pfl/ditto/client.py index 0800f2913..872e1f481 100644 --- a/research/cifar10/adaptive_pfl/ditto/client.py +++ b/research/cifar10/adaptive_pfl/ditto/client.py @@ -152,7 +152,11 @@ def get_model(self, config: Config) -> nn.Module: data_path = Path(args.dataset_dir) client = CifarDittoClient( data_path=data_path, - metrics=[Accuracy("accuracy"), F1("F1_Score", average="macro")], + metrics=[ + Accuracy("accuracy"), + F1("f1_score_macro", average="macro"), + F1("f1_score_weight", average="weighted"), + ], 
device=DEVICE, client_number=args.client_number, learning_rate=args.learning_rate, diff --git a/research/cifar10/adaptive_pfl/fedprox/client.py b/research/cifar10/adaptive_pfl/fedprox/client.py index 394134876..18be6b303 100644 --- a/research/cifar10/adaptive_pfl/fedprox/client.py +++ b/research/cifar10/adaptive_pfl/fedprox/client.py @@ -156,7 +156,11 @@ def get_model(self, config: Config) -> nn.Module: data_path = Path(args.dataset_dir) client = CifarFedProxClient( data_path=data_path, - metrics=[Accuracy("accuracy"), F1("F1_Score", average="macro")], + metrics=[ + Accuracy("accuracy"), + F1("f1_score_macro", average="macro"), + F1("f1_score_weight", average="weighted"), + ], device=DEVICE, client_number=args.client_number, learning_rate=args.learning_rate, diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/client.py b/research/cifar10/adaptive_pfl/fenda_ditto/client.py index 15dc527e1..e5682050f 100644 --- a/research/cifar10/adaptive_pfl/fenda_ditto/client.py +++ b/research/cifar10/adaptive_pfl/fenda_ditto/client.py @@ -166,7 +166,11 @@ def get_global_model(self, config: Config) -> SequentiallySplitModel: data_path = Path(args.dataset_dir) client = CifarFendaDittoClient( data_path=data_path, - metrics=[Accuracy("accuracy"), F1("F1_Score", average="macro")], + metrics=[ + Accuracy("accuracy"), + F1("f1_score_macro", average="macro"), + F1("f1_score_weight", average="weighted"), + ], device=DEVICE, client_number=args.client_number, learning_rate=args.learning_rate, diff --git a/research/cifar10/adaptive_pfl/mrmtl/client.py b/research/cifar10/adaptive_pfl/mrmtl/client.py index 47dee7393..69a5e430d 100644 --- a/research/cifar10/adaptive_pfl/mrmtl/client.py +++ b/research/cifar10/adaptive_pfl/mrmtl/client.py @@ -150,7 +150,11 @@ def get_model(self, config: Config) -> nn.Module: data_path = Path(args.dataset_dir) client = CifarMrMtlClient( data_path=data_path, - metrics=[Accuracy("accuracy"), F1("F1_Score", average="macro")], + metrics=[ + Accuracy("accuracy"), + 
F1("f1_score_macro", average="macro"), + F1("f1_score_weight", average="weighted"), + ], device=DEVICE, client_number=args.client_number, learning_rate=args.learning_rate, diff --git a/research/cifar10/fed_dgga_pfl/ditto/client.py b/research/cifar10/fed_dgga_pfl/ditto/client.py index 0800f2913..872e1f481 100644 --- a/research/cifar10/fed_dgga_pfl/ditto/client.py +++ b/research/cifar10/fed_dgga_pfl/ditto/client.py @@ -152,7 +152,11 @@ def get_model(self, config: Config) -> nn.Module: data_path = Path(args.dataset_dir) client = CifarDittoClient( data_path=data_path, - metrics=[Accuracy("accuracy"), F1("F1_Score", average="macro")], + metrics=[ + Accuracy("accuracy"), + F1("f1_score_macro", average="macro"), + F1("f1_score_weight", average="weighted"), + ], device=DEVICE, client_number=args.client_number, learning_rate=args.learning_rate, diff --git a/research/cifar10/fed_dgga_pfl/ditto/server.py b/research/cifar10/fed_dgga_pfl/ditto/server.py index 172336f71..b8be10078 100644 --- a/research/cifar10/fed_dgga_pfl/ditto/server.py +++ b/research/cifar10/fed_dgga_pfl/ditto/server.py @@ -46,7 +46,6 @@ def main(config: Dict[str, Any], server_address: str, lam: float, step_size: flo config["local_epochs"], config["n_server_rounds"], config["n_clients"], - config["evaluate_after_fit"], evaluate_after_fit=config.get("evaluate_after_fit", False), pack_losses_with_val_metrics=config.get("pack_losses_with_val_metrics", False), ) diff --git a/research/cifar10/fed_dgga_pfl/fenda/client.py b/research/cifar10/fed_dgga_pfl/fenda/client.py index 4b5ba659f..1d1abd03e 100644 --- a/research/cifar10/fed_dgga_pfl/fenda/client.py +++ b/research/cifar10/fed_dgga_pfl/fenda/client.py @@ -150,7 +150,11 @@ def get_model(self, config: Config) -> FendaModel: data_path = Path(args.dataset_dir) client = CifarFendaClient( data_path=data_path, - metrics=[Accuracy("accuracy"), F1("F1_Score", average="macro")], + metrics=[ + Accuracy("accuracy"), + F1("f1_score_macro", average="macro"), + 
F1("f1_score_weight", average="weighted"), + ], device=DEVICE, client_number=args.client_number, learning_rate=args.learning_rate, diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py b/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py index d33e6b0da..298525e71 100644 --- a/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py @@ -166,7 +166,11 @@ def get_global_model(self, config: Config) -> SequentiallySplitModel: data_path = Path(args.dataset_dir) client = CifarFendaDittoClient( data_path=data_path, - metrics=[Accuracy("accuracy"), F1("F1_Score", average="macro")], + metrics=[ + Accuracy("accuracy"), + F1("f1_score_macro", average="macro"), + F1("f1_score_weight", average="weighted"), + ], device=DEVICE, client_number=args.client_number, learning_rate=args.learning_rate, From 77355a4063cf6521ea023f79b305a1c7bf304c38 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Fri, 1 Nov 2024 12:09:43 -0400 Subject: [PATCH 15/19] Cleaning up some missed variable renames --- research/cifar10/fed_dgga_pfl/ditto/server.py | 2 +- research/cifar10/fed_dgga_pfl/fenda/server.py | 2 +- research/cifar10/fed_dgga_pfl/fenda_ditto/server.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/research/cifar10/fed_dgga_pfl/ditto/server.py b/research/cifar10/fed_dgga_pfl/ditto/server.py index b8be10078..fa8b46a7c 100644 --- a/research/cifar10/fed_dgga_pfl/ditto/server.py +++ b/research/cifar10/fed_dgga_pfl/ditto/server.py @@ -75,7 +75,7 @@ def main(config: Dict[str, Any], server_address: str, lam: float, step_size: flo evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, initial_parameters=get_all_model_parameters(model), initial_loss_weight=lam, - weight_step_size=step_size, + adjustment_weight_step_size=step_size, fairness_metric=ditto_fairness_metric, ) diff --git a/research/cifar10/fed_dgga_pfl/fenda/server.py 
b/research/cifar10/fed_dgga_pfl/fenda/server.py index 8306da935..b9408af6d 100644 --- a/research/cifar10/fed_dgga_pfl/fenda/server.py +++ b/research/cifar10/fed_dgga_pfl/fenda/server.py @@ -72,7 +72,7 @@ def main(config: Dict[str, Any], server_address: str, step_size: float) -> None: fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, initial_parameters=get_all_model_parameters(model), - weight_step_size=step_size, + adjustment_weight_step_size=step_size, fairness_metric=fenda_fairness_metric, ) diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py b/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py index bd991c0ac..2f373d895 100644 --- a/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py @@ -75,7 +75,7 @@ def main(config: Dict[str, Any], server_address: str, lam: float, step_size: flo evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, initial_parameters=get_all_model_parameters(model), initial_loss_weight=lam, - weight_step_size=step_size, + adjustment_weight_step_size=step_size, fairness_metric=fenda_ditto_fairness_metric, ) From d364d389166f0962324d6f4653a71ad3a1063b23 Mon Sep 17 00:00:00 2001 From: David Emerson Date: Tue, 5 Nov 2024 09:20:21 -0500 Subject: [PATCH 16/19] Fixing a bit more determinism in the randomness. 
Switching over to lower gpus to see if everything fits on less in-demand hardware --- fl4health/utils/random.py | 8 ++++++-- .../cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm | 2 +- .../cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm | 2 +- .../adaptive_pfl/fenda_ditto/run_fold_experiment.slrm | 2 +- .../cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm | 2 +- .../cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm | 2 +- .../cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm | 2 +- .../fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm | 2 +- 8 files changed, 13 insertions(+), 9 deletions(-) diff --git a/fl4health/utils/random.py b/fl4health/utils/random.py index b84f2a90b..22e70e156 100644 --- a/fl4health/utils/random.py +++ b/fl4health/utils/random.py @@ -19,10 +19,12 @@ def set_all_random_seeds(seed: Optional[int] = 42) -> None: if seed is None: log(INFO, "No seed provided. Using random seed.") else: - log(INFO, f"Setting seed to {seed}") + log(INFO, f"Setting seed to {seed} and fixing torch determinism") random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) + torch.use_deterministic_algorithms(True) + torch.backends.cudnn.benchmark = False def unset_all_random_seeds() -> None: @@ -30,10 +32,12 @@ def unset_all_random_seeds() -> None: Set random seeds for Python random, NumPy, and PyTorch to None. Running this function would undo the effects of set_all_random_seeds. """ - log(INFO, "Setting all random seeds to None.") + log(INFO, "Setting all random seeds to None. 
Reverting torch determinism settings") random.seed(None) np.random.seed(None) torch.seed() + torch.use_deterministic_algorithms(False) + torch.backends.cudnn.benchmark = True def generate_hash(length: int = 8) -> str: diff --git a/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm index 4a1d81448..f170822b2 100644 --- a/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm @@ -6,7 +6,7 @@ #SBATCH --cpus-per-task=1 #SBATCH --gres=gpu:1 #SBATCH --mem=32G -#SBATCH --partition=a40 +#SBATCH --partition=rtx6000 #SBATCH --qos=m2 #SBATCH --job-name=fl_five_fold_exp #SBATCH --output=%j_%x.out diff --git a/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm index 4187e63a5..59f77dff8 100644 --- a/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm @@ -6,7 +6,7 @@ #SBATCH --cpus-per-task=1 #SBATCH --gres=gpu:1 #SBATCH --mem=32G -#SBATCH --partition=a40 +#SBATCH --partition=rtx6000 #SBATCH --qos=m2 #SBATCH --job-name=fl_five_fold_exp #SBATCH --output=%j_%x.out diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm index aec4240e2..58d39133a 100644 --- a/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm @@ -6,7 +6,7 @@ #SBATCH --cpus-per-task=1 #SBATCH --gres=gpu:1 #SBATCH --mem=32G -#SBATCH --partition=a40 +#SBATCH --partition=rtx6000 #SBATCH --qos=m2 #SBATCH --job-name=fl_five_fold_exp #SBATCH --output=%j_%x.out diff --git a/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm index 7c17f3762..b216abf2b 100644 --- 
a/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm @@ -6,7 +6,7 @@ #SBATCH --cpus-per-task=1 #SBATCH --gres=gpu:1 #SBATCH --mem=32G -#SBATCH --partition=a40 +#SBATCH --partition=rtx6000 #SBATCH --qos=m2 #SBATCH --job-name=fl_five_fold_exp #SBATCH --output=%j_%x.out diff --git a/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm b/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm index d31b6031f..37ecfafd1 100644 --- a/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm +++ b/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm @@ -6,7 +6,7 @@ #SBATCH --cpus-per-task=1 #SBATCH --gres=gpu:1 #SBATCH --mem=32G -#SBATCH --partition=a40 +#SBATCH --partition=rtx6000 #SBATCH --qos=m2 #SBATCH --job-name=fl_five_fold_exp #SBATCH --output=%j_%x.out diff --git a/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm b/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm index 610c57022..9ca21399c 100644 --- a/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm +++ b/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm @@ -6,7 +6,7 @@ #SBATCH --cpus-per-task=1 #SBATCH --gres=gpu:1 #SBATCH --mem=32G -#SBATCH --partition=a40 +#SBATCH --partition=rtx6000 #SBATCH --qos=m2 #SBATCH --job-name=fl_five_fold_exp #SBATCH --output=%j_%x.out diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm b/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm index 2c74e088d..0dfa94b8f 100644 --- a/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm @@ -6,7 +6,7 @@ #SBATCH --cpus-per-task=1 #SBATCH --gres=gpu:1 #SBATCH --mem=32G -#SBATCH --partition=a40 +#SBATCH --partition=rtx6000 #SBATCH --qos=m2 #SBATCH --job-name=fl_five_fold_exp #SBATCH --output=%j_%x.out From 55df53faa9ee6c5faad83475f28523615b539162 Mon Sep 17 00:00:00 2001 
From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Thu, 7 Nov 2024 08:51:29 -0500 Subject: [PATCH 17/19] Expanding lam to lambda in arguments passed to scripts --- research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm | 4 ++-- research/cifar10/adaptive_pfl/ditto/server.py | 2 +- .../cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm | 4 ++-- research/cifar10/adaptive_pfl/fedprox/server.py | 2 +- .../cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm | 4 ++-- research/cifar10/adaptive_pfl/fenda_ditto/server.py | 2 +- research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm | 4 ++-- research/cifar10/adaptive_pfl/mrmtl/server.py | 2 +- research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm | 2 +- research/cifar10/fed_dgga_pfl/ditto/server.py | 2 +- .../cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm | 2 +- research/cifar10/fed_dgga_pfl/fenda_ditto/server.py | 2 +- 12 files changed, 16 insertions(+), 16 deletions(-) diff --git a/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm index f170822b2..aa8e256a0 100644 --- a/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm @@ -132,7 +132,7 @@ do --config_path ${SERVER_CONFIG_PATH} \ --server_address ${SERVER_ADDRESS} \ --seed ${SEED} \ - --lam ${LAM_VALUE} \ + --lambda ${LAM_VALUE} \ --use_adaptation \ > ${SERVER_OUTPUT_FILE} 2>&1 & else @@ -140,7 +140,7 @@ do --config_path ${SERVER_CONFIG_PATH} \ --server_address ${SERVER_ADDRESS} \ --seed ${SEED} \ - --lam ${LAM_VALUE} \ + --lambda ${LAM_VALUE} \ > ${SERVER_OUTPUT_FILE} 2>&1 & fi diff --git a/research/cifar10/adaptive_pfl/ditto/server.py b/research/cifar10/adaptive_pfl/ditto/server.py index e3500627b..9d82c9461 100644 --- a/research/cifar10/adaptive_pfl/ditto/server.py +++ b/research/cifar10/adaptive_pfl/ditto/server.py @@ -101,7 +101,7 @@ def main(config: Dict[str, 
Any], server_address: str, lam: float, adapt_loss_wei required=False, ) parser.add_argument( - "--lam", action="store", type=float, help="Ditto loss weight for local model training", default=0.01 + "--lambda", action="store", type=float, help="Ditto loss weight for local model training", default=0.01 ) parser.add_argument( "--use_adaptation", diff --git a/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm index 59f77dff8..f95dae6c6 100644 --- a/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm @@ -134,7 +134,7 @@ do --run_name ${RUN_NAME} \ --server_address ${SERVER_ADDRESS} \ --seed ${SEED} \ - --lam ${LAM_VALUE} \ + --lambda ${LAM_VALUE} \ --use_adaptation \ > ${SERVER_OUTPUT_FILE} 2>&1 & else @@ -144,7 +144,7 @@ do --run_name ${RUN_NAME} \ --server_address ${SERVER_ADDRESS} \ --seed ${SEED} \ - --lam ${LAM_VALUE} \ + --lambda ${LAM_VALUE} \ > ${SERVER_OUTPUT_FILE} 2>&1 & fi diff --git a/research/cifar10/adaptive_pfl/fedprox/server.py b/research/cifar10/adaptive_pfl/fedprox/server.py index 5032342a0..ecc177203 100644 --- a/research/cifar10/adaptive_pfl/fedprox/server.py +++ b/research/cifar10/adaptive_pfl/fedprox/server.py @@ -136,7 +136,7 @@ def main( required=False, ) parser.add_argument( - "--lam", action="store", type=float, help="FedProx loss weight for local model training", default=0.01 + "--lambda", action="store", type=float, help="FedProx loss weight for local model training", default=0.01 ) parser.add_argument( "--use_adaptation", diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm index 58d39133a..ac4909432 100644 --- a/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm @@ -135,7 +135,7 @@ do --config_path 
${SERVER_CONFIG_PATH} \ --server_address ${SERVER_ADDRESS} \ --seed ${SEED} \ - --lam ${LAM_VALUE} \ + --lambda ${LAM_VALUE} \ --use_adaptation \ > ${SERVER_OUTPUT_FILE} 2>&1 & else @@ -143,7 +143,7 @@ do --config_path ${SERVER_CONFIG_PATH} \ --server_address ${SERVER_ADDRESS} \ --seed ${SEED} \ - --lam ${LAM_VALUE} \ + --lambda ${LAM_VALUE} \ > ${SERVER_OUTPUT_FILE} 2>&1 & fi diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/server.py b/research/cifar10/adaptive_pfl/fenda_ditto/server.py index 3a45d46cd..b576852f1 100644 --- a/research/cifar10/adaptive_pfl/fenda_ditto/server.py +++ b/research/cifar10/adaptive_pfl/fenda_ditto/server.py @@ -101,7 +101,7 @@ def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_wei required=False, ) parser.add_argument( - "--lam", action="store", type=float, help="FENDA Ditto loss weight for local model training", default=0.01 + "--lambda", action="store", type=float, help="FENDA Ditto loss weight for local model training", default=0.01 ) parser.add_argument( "--use_adaptation", diff --git a/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm index b216abf2b..2f708bce0 100644 --- a/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm @@ -132,7 +132,7 @@ do --config_path ${SERVER_CONFIG_PATH} \ --server_address ${SERVER_ADDRESS} \ --seed ${SEED} \ - --lam ${LAM_VALUE} \ + --lambda ${LAM_VALUE} \ --use_adaptation \ > ${SERVER_OUTPUT_FILE} 2>&1 & else @@ -140,7 +140,7 @@ do --config_path ${SERVER_CONFIG_PATH} \ --server_address ${SERVER_ADDRESS} \ --seed ${SEED} \ - --lam ${LAM_VALUE} \ + --lambda ${LAM_VALUE} \ > ${SERVER_OUTPUT_FILE} 2>&1 & fi diff --git a/research/cifar10/adaptive_pfl/mrmtl/server.py b/research/cifar10/adaptive_pfl/mrmtl/server.py index f965aa813..f3a5e8c60 100644 --- a/research/cifar10/adaptive_pfl/mrmtl/server.py +++ 
b/research/cifar10/adaptive_pfl/mrmtl/server.py @@ -125,7 +125,7 @@ def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_wei required=False, ) parser.add_argument( - "--lam", action="store", type=float, help="Ditto loss weight for local model training", default=0.01 + "--lambda", action="store", type=float, help="Ditto loss weight for local model training", default=0.01 ) parser.add_argument( "--use_adaptation", diff --git a/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm b/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm index 37ecfafd1..7d95fe7e4 100644 --- a/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm +++ b/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm @@ -130,7 +130,7 @@ do --config_path ${SERVER_CONFIG_PATH} \ --server_address ${SERVER_ADDRESS} \ --seed ${SEED} \ - --lam ${LAM_VALUE} \ + --lambda ${LAM_VALUE} \ --step_size ${STEP_SIZE} \ > ${SERVER_OUTPUT_FILE} 2>&1 & diff --git a/research/cifar10/fed_dgga_pfl/ditto/server.py b/research/cifar10/fed_dgga_pfl/ditto/server.py index fa8b46a7c..e95b8543c 100644 --- a/research/cifar10/fed_dgga_pfl/ditto/server.py +++ b/research/cifar10/fed_dgga_pfl/ditto/server.py @@ -118,7 +118,7 @@ def main(config: Dict[str, Any], server_address: str, lam: float, step_size: flo required=False, ) parser.add_argument( - "--lam", action="store", type=float, help="Ditto loss weight for local model training", default=0.01 + "--lambda", action="store", type=float, help="Ditto loss weight for local model training", default=0.01 ) parser.add_argument( "--step_size", diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm b/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm index 0dfa94b8f..265cf4eee 100644 --- a/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm @@ -134,7 +134,7 @@ do --config_path ${SERVER_CONFIG_PATH} \ 
--server_address ${SERVER_ADDRESS} \ --seed ${SEED} \ - --lam ${LAM_VALUE} \ + --lambda ${LAM_VALUE} \ --step_size ${STEP_SIZE} \ > ${SERVER_OUTPUT_FILE} 2>&1 & diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py b/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py index 2f373d895..3a3452cbd 100644 --- a/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py @@ -118,7 +118,7 @@ def main(config: Dict[str, Any], server_address: str, lam: float, step_size: flo required=False, ) parser.add_argument( - "--lam", action="store", type=float, help="FENDA Ditto loss weight for local model training", default=0.01 + "--lambda", action="store", type=float, help="FENDA Ditto loss weight for local model training", default=0.01 ) parser.add_argument( "--step_size", From edc846307f8e74230277de2b7b6213fc4ae7d6fb Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Thu, 7 Nov 2024 09:00:11 -0500 Subject: [PATCH 18/19] Changes to make max retries a touch more clear and cleaner --- fl4health/utils/partitioners.py | 17 +++++++++++------ research/cifar10/preprocess.py | 6 +++--- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/fl4health/utils/partitioners.py b/fl4health/utils/partitioners.py index c49c934f9..b9b2492a2 100644 --- a/fl4health/utils/partitioners.py +++ b/fl4health/utils/partitioners.py @@ -148,7 +148,9 @@ def partition_label_indices( # Dropping the last partition as they are "excess" indices return partitioned_indices[:-1], min_samples, partition_allocations - def partition_dataset(self, original_dataset: D, max_retries: int = 5) -> Tuple[List[D], Dict[T, np.ndarray]]: + def partition_dataset( + self, original_dataset: D, max_retries: Optional[int] = 5 + ) -> Tuple[List[D], Dict[T, np.ndarray]]: """ Attempts partitioning of the original dataset up to max_retries times. 
Retries are potentially required if the user requests a minimum number of labels be assigned to each of the partitions. If the drawn Dirichlet @@ -157,16 +159,19 @@ def partition_dataset(self, original_dataset: D, max_retries: int = 5) -> Tuple[ Args: original_dataset (D): The dataset to be partitioned - max_retries (int, optional): Number of times to attempt to satisfy a user provided minimum - label-associated data points per partition. Defaults to 5. + max_retries (Optional[int], optional): Number of times to attempt to satisfy a user provided minimum + label-associated data points per partition. Set this value to None if you want to retry indefinitely. + Defaults to 5. Raises: ValueError: Throws this error if the retries have been exhausted and the user provided minimum is not met. Returns: - List[D]: The partitioned datasets, length should correspond to self.number_of_partitions - Dict[T, np.ndarray]: The Dirichlet distribution used to partition the data points for each label. + Tuple[List[D], Dict[T, np.ndarray]]: List[D] is the partitioned datasets, length should correspond to + self.number_of_partitions. Dict[T, np.ndarray] is the Dirichlet distribution used to partition the data + points for each label. """ + targets = original_dataset.targets assert targets is not None, "A label-based partitioner requires targets but this dataset has no targets" partitioned_indices = [torch.Tensor([]).int() for _ in range(self.number_of_partitions)] @@ -195,7 +200,7 @@ def partition_dataset(self, original_dataset: D, max_retries: int = 5) -> Tuple[ f"minimum requested was {self.min_label_examples}. Resampling the partition..." ), ) - if partition_attempts == max_retries: + if max_retries is not None and partition_attempts >= max_retries: raise ValueError( ( f"Max Retries: {max_retries} reached. 
Partitioning failed to " diff --git a/research/cifar10/preprocess.py b/research/cifar10/preprocess.py index 8e3f9632f..58a478a1e 100644 --- a/research/cifar10/preprocess.py +++ b/research/cifar10/preprocess.py @@ -103,7 +103,7 @@ def preprocess_data( number_of_partitions=num_clients, unique_labels=list(range(10)), beta=beta, min_label_examples=1 ) train_partitioned_datasets, train_partitioned_dist = heterogeneous_partitioner.partition_dataset( - training_set, max_retries=-1 + training_set, max_retries=None ) # Partition validation and test data @@ -111,9 +111,9 @@ def preprocess_data( number_of_partitions=num_clients, unique_labels=list(range(10)), prior_distribution=train_partitioned_dist ) validation_partitioned_datasets, _ = heterogeneous_partitioner_with_prior.partition_dataset( - validation_set, max_retries=-1 + validation_set, max_retries=None ) - test_partitioned_datasets, _ = heterogeneous_partitioner_with_prior.partition_dataset(test_set, max_retries=-1) + test_partitioned_datasets, _ = heterogeneous_partitioner_with_prior.partition_dataset(test_set, max_retries=None) return train_partitioned_datasets, validation_partitioned_datasets, test_partitioned_datasets From 93988c79c5349a62e457ffa21bcb6feef1d574ea Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Thu, 7 Nov 2024 11:19:47 -0500 Subject: [PATCH 19/19] Moving the torch specific determinism to separate optional arguments so as not to disrupt any other workflows that don't need them. Added some documentation. Adding the setting of these to the appropriate places for the pFL experiments. Reducing the hidden size of the cnns used in pFL experiments, as they are quite large. 
--- fl4health/utils/random.py | 34 +++++++++++++++---- research/cifar10/adaptive_pfl/ditto/client.py | 4 +-- .../ditto/run_fold_experiment.slrm | 4 +++ research/cifar10/adaptive_pfl/ditto/server.py | 4 +-- .../cifar10/adaptive_pfl/fedprox/client.py | 4 +-- .../fedprox/run_fold_experiment.slrm | 4 +++ .../cifar10/adaptive_pfl/fedprox/server.py | 4 +-- .../adaptive_pfl/fenda_ditto/client.py | 4 +-- .../fenda_ditto/run_fold_experiment.slrm | 4 +++ .../adaptive_pfl/fenda_ditto/server.py | 4 +-- research/cifar10/adaptive_pfl/mrmtl/client.py | 4 +-- .../mrmtl/run_fold_experiment.slrm | 4 +++ research/cifar10/adaptive_pfl/mrmtl/server.py | 4 +-- research/cifar10/fed_dgga_pfl/ditto/client.py | 4 +-- .../ditto/run_fold_experiment.slrm | 4 +++ research/cifar10/fed_dgga_pfl/ditto/server.py | 4 +-- research/cifar10/fed_dgga_pfl/fenda/client.py | 4 +-- .../fenda/run_fold_experiment.slrm | 4 +++ research/cifar10/fed_dgga_pfl/fenda/server.py | 4 +-- .../fed_dgga_pfl/fenda_ditto/client.py | 6 ++-- .../fenda_ditto/run_fold_experiment.slrm | 4 +++ .../fed_dgga_pfl/fenda_ditto/server.py | 4 +-- 22 files changed, 84 insertions(+), 36 deletions(-) diff --git a/fl4health/utils/random.py b/fl4health/utils/random.py index 22e70e156..515889a67 100644 --- a/fl4health/utils/random.py +++ b/fl4health/utils/random.py @@ -8,22 +8,43 @@ from flwr.common.logger import log -def set_all_random_seeds(seed: Optional[int] = 42) -> None: - """Set seeds for python random, numpy random, and pytorch random. +def set_all_random_seeds( + seed: Optional[int] = 42, use_deterministic_torch_algos: bool = False, disable_torch_benchmarking: bool = False +) -> None: + """ + Set seeds for python random, numpy random, and pytorch random. It also offers the option to force pytorch to use + deterministic algorithms for certain methods and layers (see: + https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html) for more details.
Finally, it + allows one to disable cuda benchmarking, which can also affect the determinism of pytorch training outside of + random seeding. For more information on reproducibility in pytorch see: + https://pytorch.org/docs/stable/notes/randomness.html - Will no-op if seed is `None`. + NOTE: If the use_deterministic_torch_algos flag is set to True, you may need to set the environment variable + CUBLAS_WORKSPACE_CONFIG to something like :4096:8, to avoid CUDA errors. Additional documentation may be found + here: https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility Args: - seed (int): The seed value to be used for random number generators. Default is 42. + seed (Optional[int], optional): The seed value to be used for random number generators. Default is 42. Seed + setting will no-op if the seed is explicitly set to None + use_deterministic_torch_algos (bool, optional): Whether or not to set torch.use_deterministic_algorithms to + True. Defaults to False. + disable_torch_benchmarking (bool, optional): Whether to explicitly disable cuda benchmarking in + torch processes. Defaults to False. """ if seed is None: log(INFO, "No seed provided. Using random seed.") else: - log(INFO, f"Setting seed to {seed} and fixing torch determinism") + log(INFO, f"Setting random seeds to {seed}.") random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) - torch.use_deterministic_algorithms(True) + if use_deterministic_torch_algos: + log(INFO, "Setting torch.use_deterministic_algorithms to True.") + # warn_only is set to true so that layers and components without deterministic algorithms available will + # warn the user that they don't exist, but won't take down the process with an exception. 
+ torch.use_deterministic_algorithms(True, warn_only=True) + if disable_torch_benchmarking: + log(INFO, "Disabling CUDA algorithm benchmarking.") torch.backends.cudnn.benchmark = False @@ -37,7 +58,6 @@ def unset_all_random_seeds() -> None: np.random.seed(None) torch.seed() torch.use_deterministic_algorithms(False) - torch.backends.cudnn.benchmark = True def generate_hash(length: int = 8) -> str: diff --git a/research/cifar10/adaptive_pfl/ditto/client.py b/research/cifar10/adaptive_pfl/ditto/client.py index 872e1f481..a2e510a7c 100644 --- a/research/cifar10/adaptive_pfl/ditto/client.py +++ b/research/cifar10/adaptive_pfl/ditto/client.py @@ -65,7 +65,7 @@ def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: return {"global": global_optimizer, "local": local_optimizer} def get_model(self, config: Config) -> nn.Module: - return ConvNet(in_channels=3, use_bn=False, dropout=0.1).to(self.device) + return ConvNet(in_channels=3, use_bn=False, dropout=0.1, hidden=512).to(self.device) if __name__ == "__main__": @@ -130,7 +130,7 @@ def get_model(self, config: Config) -> nn.Module: log(INFO, f"Beta: {args.beta}") # Set the random seed for reproducibility - set_all_random_seeds(args.seed) + set_all_random_seeds(args.seed, use_deterministic_torch_algos=True, disable_torch_benchmarking=True) # Adding extensive checkpointing for the client checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) diff --git a/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm index aa8e256a0..9cc88ce05 100644 --- a/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm @@ -62,6 +62,10 @@ if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ export NCCL_SOCKET_IFNAME=bond0 fi +# This environment variable must be set in order to force torch to use deterministic algorithms.
See documentation +# in fl4health/utils/random.py for more information +export CUBLAS_WORKSPACE_CONFIG=:4096:8 + # Process Inputs SERVER_CONFIG_PATH=$1 diff --git a/research/cifar10/adaptive_pfl/ditto/server.py b/research/cifar10/adaptive_pfl/ditto/server.py index 9d82c9461..51c4c4a24 100644 --- a/research/cifar10/adaptive_pfl/ditto/server.py +++ b/research/cifar10/adaptive_pfl/ditto/server.py @@ -45,7 +45,7 @@ def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_wei client_manager = SimpleClientManager() # Initializing the model on the server side - model = ConvNet(in_channels=3, use_bn=False, dropout=0.1) + model = ConvNet(in_channels=3, use_bn=False, dropout=0.1, hidden=512) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvgWithAdaptiveConstraint( min_fit_clients=config["n_clients"], @@ -118,6 +118,6 @@ def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_wei log(INFO, "Adapting the loss weight for model drift via global model loss") # Set the random seed for reproducibility - set_all_random_seeds(args.seed) + set_all_random_seeds(args.seed, use_deterministic_torch_algos=True, disable_torch_benchmarking=True) main(config, args.server_address, args.lam, args.use_adaptation) diff --git a/research/cifar10/adaptive_pfl/fedprox/client.py b/research/cifar10/adaptive_pfl/fedprox/client.py index 18be6b303..624e7e639 100644 --- a/research/cifar10/adaptive_pfl/fedprox/client.py +++ b/research/cifar10/adaptive_pfl/fedprox/client.py @@ -63,7 +63,7 @@ def get_optimizer(self, config: Config) -> Optimizer: return torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) def get_model(self, config: Config) -> nn.Module: - return ConvNet(in_channels=3, use_bn=False, dropout=0.1).to(self.device) + return ConvNet(in_channels=3, use_bn=False, dropout=0.1, hidden=512).to(self.device) if __name__ == "__main__": @@ -134,7 +134,7 @@ def get_model(self, config: Config) -> nn.Module: 
log(INFO, f"Beta: {args.beta}") # Set the random seed for reproducibility - set_all_random_seeds(args.seed) + set_all_random_seeds(args.seed, use_deterministic_torch_algos=True, disable_torch_benchmarking=True) # Adding extensive checkpointing for the client checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) diff --git a/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm index f95dae6c6..56dec6e51 100644 --- a/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm @@ -62,6 +62,10 @@ if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ export NCCL_SOCKET_IFNAME=bond0 fi +# This environment variable must be set in order to force torch to use deterministic algorithms. See documentation +# in fl4health/utils/random.py for more information +export CUBLAS_WORKSPACE_CONFIG=:4096:8 + # Process Inputs SERVER_CONFIG_PATH=$1 diff --git a/research/cifar10/adaptive_pfl/fedprox/server.py b/research/cifar10/adaptive_pfl/fedprox/server.py index ecc177203..d8859e9f5 100644 --- a/research/cifar10/adaptive_pfl/fedprox/server.py +++ b/research/cifar10/adaptive_pfl/fedprox/server.py @@ -61,7 +61,7 @@ def main( client_manager = SimpleClientManager() # Initializing the model on the server side - model = ConvNet(in_channels=3, use_bn=False, dropout=0.1) + model = ConvNet(in_channels=3, use_bn=False, dropout=0.1, hidden=512) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvgWithAdaptiveConstraint( min_fit_clients=config["n_clients"], @@ -153,6 +153,6 @@ def main( log(INFO, "Adapting the loss weight for model drift via model loss") # Set the random seed for reproducibility - set_all_random_seeds(args.seed) + set_all_random_seeds(args.seed, use_deterministic_torch_algos=True, disable_torch_benchmarking=True) main(config, args.server_address, args.artifact_dir, args.run_name, args.lam,
args.use_adaptation) diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/client.py b/research/cifar10/adaptive_pfl/fenda_ditto/client.py index e5682050f..421fba250 100644 --- a/research/cifar10/adaptive_pfl/fenda_ditto/client.py +++ b/research/cifar10/adaptive_pfl/fenda_ditto/client.py @@ -68,7 +68,7 @@ def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: return {"global": global_optimizer, "local": local_optimizer} def get_model(self, config: Config) -> FendaModel: - return ConvNetFendaModel(in_channels=3, use_bn=False).to(self.device) + return ConvNetFendaModel(in_channels=3, use_bn=False, dropout=0.1, hidden=512).to(self.device) def get_global_model(self, config: Config) -> SequentiallySplitModel: return ConvNetFendaDittoGlobalModel(in_channels=3, use_bn=False, dropout=0.1).to(self.device) @@ -144,7 +144,7 @@ def get_global_model(self, config: Config) -> SequentiallySplitModel: log(INFO, "Freezing the global feature extractor of the FENDA model") # Set the random seed for reproducibility - set_all_random_seeds(args.seed) + set_all_random_seeds(args.seed, use_deterministic_torch_algos=True, disable_torch_benchmarking=True) # Adding extensive checkpointing for the client checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm index ac4909432..e30981f53 100644 --- a/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm @@ -64,6 +64,10 @@ if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ export NCCL_SOCKET_IFNAME=bond0 fi +# This environment variable must be set in order to force torch to use deterministic algorithms.
See documentation +# in fl4health/utils/random.py for more information +export CUBLAS_WORKSPACE_CONFIG=:4096:8 + # Process Inputs SERVER_CONFIG_PATH=$1 diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/server.py b/research/cifar10/adaptive_pfl/fenda_ditto/server.py index b576852f1..464491f64 100644 --- a/research/cifar10/adaptive_pfl/fenda_ditto/server.py +++ b/research/cifar10/adaptive_pfl/fenda_ditto/server.py @@ -45,7 +45,7 @@ def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_wei client_manager = SimpleClientManager() # Initializing the model on the server side - model = ConvNetFendaDittoGlobalModel(in_channels=3, use_bn=False, dropout=0.1) + model = ConvNetFendaDittoGlobalModel(in_channels=3, use_bn=False, dropout=0.1, hidden=512) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvgWithAdaptiveConstraint( min_fit_clients=config["n_clients"], @@ -118,6 +118,6 @@ def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_wei log(INFO, "Adapting the loss weight for model drift via global model loss") # Set the random seed for reproducibility - set_all_random_seeds(args.seed) + set_all_random_seeds(args.seed, use_deterministic_torch_algos=True, disable_torch_benchmarking=True) main(config, args.server_address, args.lam, args.use_adaptation) diff --git a/research/cifar10/adaptive_pfl/mrmtl/client.py b/research/cifar10/adaptive_pfl/mrmtl/client.py index 69a5e430d..0cc5e1939 100644 --- a/research/cifar10/adaptive_pfl/mrmtl/client.py +++ b/research/cifar10/adaptive_pfl/mrmtl/client.py @@ -63,7 +63,7 @@ def get_optimizer(self, config: Config) -> Optimizer: return torch.optim.SGD(self.model.parameters(), lr=self.learning_rate) def get_model(self, config: Config) -> nn.Module: - return ConvNet(in_channels=3, use_bn=False, dropout=0.1).to(self.device) + return ConvNet(in_channels=3, use_bn=False, dropout=0.1, hidden=512).to(self.device) if __name__ == "__main__": @@ -128,7 +128,7 
@@ def get_model(self, config: Config) -> nn.Module: log(INFO, f"Beta: {args.beta}") # Set the random seed for reproducibility - set_all_random_seeds(args.seed) + set_all_random_seeds(args.seed, use_deterministic_torch_algos=True, disable_torch_benchmarking=True) # Adding extensive checkpointing for the client checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) diff --git a/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm index 2f708bce0..42056ec0f 100644 --- a/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm @@ -62,6 +62,10 @@ if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ export NCCL_SOCKET_IFNAME=bond0 fi +# This environment variable must be set in order to force torch to use deterministic algorithms. See documentation +# in fl4health/utils/random.py for more information +export CUBLAS_WORKSPACE_CONFIG=:4096:8 + # Process Inputs SERVER_CONFIG_PATH=$1 diff --git a/research/cifar10/adaptive_pfl/mrmtl/server.py b/research/cifar10/adaptive_pfl/mrmtl/server.py index f3a5e8c60..6a3b57d57 100644 --- a/research/cifar10/adaptive_pfl/mrmtl/server.py +++ b/research/cifar10/adaptive_pfl/mrmtl/server.py @@ -69,7 +69,7 @@ def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_wei client_manager = SimpleClientManager() # Initializing the model on the server side - model = ConvNet(in_channels=3, use_bn=False, dropout=0.1) + model = ConvNet(in_channels=3, use_bn=False, dropout=0.1, hidden=512) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvgWithAdaptiveConstraint( min_fit_clients=config["n_clients"], @@ -142,6 +142,6 @@ def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_wei log(INFO, "Adapting the loss weight for model drift via global model loss") # Set the random seed for reproducibility -
set_all_random_seeds(args.seed) + set_all_random_seeds(args.seed, use_deterministic_torch_algos=True, disable_torch_benchmarking=True) main(config, args.server_address, args.lam, args.use_adaptation) diff --git a/research/cifar10/fed_dgga_pfl/ditto/client.py b/research/cifar10/fed_dgga_pfl/ditto/client.py index 872e1f481..a2e510a7c 100644 --- a/research/cifar10/fed_dgga_pfl/ditto/client.py +++ b/research/cifar10/fed_dgga_pfl/ditto/client.py @@ -65,7 +65,7 @@ def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: return {"global": global_optimizer, "local": local_optimizer} def get_model(self, config: Config) -> nn.Module: - return ConvNet(in_channels=3, use_bn=False, dropout=0.1).to(self.device) + return ConvNet(in_channels=3, use_bn=False, dropout=0.1, hidden=512).to(self.device) if __name__ == "__main__": @@ -130,7 +130,7 @@ def get_model(self, config: Config) -> nn.Module: log(INFO, f"Beta: {args.beta}") # Set the random seed for reproducibility - set_all_random_seeds(args.seed) + set_all_random_seeds(args.seed, use_deterministic_torch_algos=True, disable_torch_benchmarking=True) # Adding extensive checkpointing for the client checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) diff --git a/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm b/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm index 7d95fe7e4..e1e76f248 100644 --- a/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm +++ b/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm @@ -61,6 +61,10 @@ if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ export NCCL_SOCKET_IFNAME=bond0 fi +# This environment variable must be set in order to force torch to use deterministic algorithms.
See documentation +# in fl4health/utils/random.py for more information +export CUBLAS_WORKSPACE_CONFIG=:4096:8 + # Process Inputs SERVER_CONFIG_PATH=$1 diff --git a/research/cifar10/fed_dgga_pfl/ditto/server.py b/research/cifar10/fed_dgga_pfl/ditto/server.py index e95b8543c..4bb2d5a6d 100644 --- a/research/cifar10/fed_dgga_pfl/ditto/server.py +++ b/research/cifar10/fed_dgga_pfl/ditto/server.py @@ -56,7 +56,7 @@ def main(config: Dict[str, Any], server_address: str, lam: float, step_size: flo # is done right before fit_round. client_manager = FixedSamplingClientManager() # Initializing the model on the server side - model = ConvNet(in_channels=3, use_bn=False, dropout=0.1) + model = ConvNet(in_channels=3, use_bn=False, dropout=0.1, hidden=512) # Define a fairness metric based on the loss associated with the global Ditto model as that is the one being # aggregated by the server. @@ -135,6 +135,6 @@ def main(config: Dict[str, Any], server_address: str, lam: float, step_size: flo log(INFO, f"Step Size: {args.step_size}") # Set the random seed for reproducibility - set_all_random_seeds(args.seed) + set_all_random_seeds(args.seed, use_deterministic_torch_algos=True, disable_torch_benchmarking=True) main(config, args.server_address, args.lam, args.step_size) diff --git a/research/cifar10/fed_dgga_pfl/fenda/client.py b/research/cifar10/fed_dgga_pfl/fenda/client.py index 1d1abd03e..7c8828465 100644 --- a/research/cifar10/fed_dgga_pfl/fenda/client.py +++ b/research/cifar10/fed_dgga_pfl/fenda/client.py @@ -63,7 +63,7 @@ def get_optimizer(self, config: Config) -> Optimizer: return torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) def get_model(self, config: Config) -> FendaModel: - return ConvNetFendaModel(in_channels=3, use_bn=False, dropout=0.1).to(self.device) + return ConvNetFendaModel(in_channels=3, use_bn=False, dropout=0.1, hidden=512).to(self.device) if __name__ == "__main__": @@ -128,7 +128,7 @@ def get_model(self, config: Config) -> FendaModel: 
log(INFO, f"Beta: {args.beta}") # Set the random seed for reproducibility - set_all_random_seeds(args.seed) + set_all_random_seeds(args.seed, use_deterministic_torch_algos=True, disable_torch_benchmarking=True) # Adding extensive checkpointing for the client checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) diff --git a/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm b/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm index 9ca21399c..afd97e40b 100644 --- a/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm +++ b/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm @@ -60,6 +60,10 @@ if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ export NCCL_SOCKET_IFNAME=bond0 fi +# This environment variable must be set in order to force torch to use deterministic algorithms. See documentation +# in fl4health/utils/random.py for more information +export CUBLAS_WORKSPACE_CONFIG=:4096:8 + # Process Inputs SERVER_CONFIG_PATH=$1 diff --git a/research/cifar10/fed_dgga_pfl/fenda/server.py b/research/cifar10/fed_dgga_pfl/fenda/server.py index b9408af6d..a880df617 100644 --- a/research/cifar10/fed_dgga_pfl/fenda/server.py +++ b/research/cifar10/fed_dgga_pfl/fenda/server.py @@ -55,7 +55,7 @@ def main(config: Dict[str, Any], server_address: str, step_size: float) -> None: # is done right before fit_round.
client_manager = FixedSamplingClientManager() # Initializing the model on the server side - model = ConvNetFendaModel(in_channels=3, use_bn=False, dropout=0.1) + model = ConvNetFendaModel(in_channels=3, use_bn=False, dropout=0.1, hidden=512) # Define a fairness metric based on the loss associated with the whole FENDA model fenda_fairness_metric = FairnessMetric(FairnessMetricType.LOSS) @@ -128,6 +128,6 @@ def main(config: Dict[str, Any], server_address: str, step_size: float) -> None: log(INFO, f"Step Size: {args.step_size}") # Set the random seed for reproducibility - set_all_random_seeds(args.seed) + set_all_random_seeds(args.seed, use_deterministic_torch_algos=True, disable_torch_benchmarking=True) main(config, args.server_address, args.step_size) diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py b/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py index 298525e71..792178240 100644 --- a/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py @@ -68,10 +68,10 @@ def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: return {"global": global_optimizer, "local": local_optimizer} def get_model(self, config: Config) -> FendaModel: - return ConvNetFendaModel(in_channels=3, use_bn=False, dropout=0.1).to(self.device) + return ConvNetFendaModel(in_channels=3, use_bn=False, dropout=0.1, hidden=512).to(self.device) def get_global_model(self, config: Config) -> SequentiallySplitModel: - return ConvNetFendaDittoGlobalModel(in_channels=3, use_bn=False, dropout=0.1).to(self.device) + return ConvNetFendaDittoGlobalModel(in_channels=3, use_bn=False, dropout=0.1, hidden=512).to(self.device) if __name__ == "__main__": @@ -144,7 +144,7 @@ def get_global_model(self, config: Config) -> SequentiallySplitModel: log(INFO, "Freezing the global feature extractor of the FENDA model") # Set the random seed for reproducibility - set_all_random_seeds(args.seed) + set_all_random_seeds(args.seed, 
use_deterministic_torch_algos=True, disable_torch_benchmarking=True) # Adding extensive checkpointing for the client checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm b/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm index 265cf4eee..185032b6f 100644 --- a/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm @@ -64,6 +64,10 @@ if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ export NCCL_SOCKET_IFNAME=bond0 fi +# This environment variable must be set in order to force torch to use deterministic algorithms. See documentation +# in fl4health/utils/random.py for more information +export CUBLAS_WORKSPACE_CONFIG=:4096:8 + # Process Inputs SERVER_CONFIG_PATH=$1 diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py b/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py index 3a3452cbd..2104c559a 100644 --- a/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py @@ -56,7 +56,7 @@ def main(config: Dict[str, Any], server_address: str, lam: float, step_size: flo # is done right before fit_round. client_manager = FixedSamplingClientManager() # Initializing the model on the server side - model = ConvNetFendaDittoGlobalModel(in_channels=3, use_bn=False, dropout=0.1) + model = ConvNetFendaDittoGlobalModel(in_channels=3, use_bn=False, dropout=0.1, hidden=512) # Define a fairness metric based on the loss associated with the global Ditto model as that is the one being # aggregated by the server.
@@ -135,6 +135,6 @@ def main(config: Dict[str, Any], server_address: str, lam: float, step_size: flo log(INFO, f"Step Size: {args.step_size}") # Set the random seed for reproducibility - set_all_random_seeds(args.seed) + set_all_random_seeds(args.seed, use_deterministic_torch_algos=True, disable_torch_benchmarking=True) main(config, args.server_address, args.lam, args.step_size)