From e657116499837d007364277448d8a49f69eb18df Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Fri, 29 Nov 2024 16:48:53 -0500 Subject: [PATCH 1/3] Fixing some run issues with the examples, turning off failures for all servers, dropping some print statements --- .../ae_examples/cvae_dim_example/client.py | 4 ++-- .../ae_examples/cvae_dim_example/server.py | 1 + .../cvae_examples/conv_cvae_example/client.py | 4 ++-- .../cvae_examples/conv_cvae_example/server.py | 1 + .../cvae_examples/mlp_cvae_example/client.py | 4 ++-- .../cvae_examples/mlp_cvae_example/server.py | 1 + .../ae_examples/fedprox_vae_example/client.py | 4 ++-- .../fedprox_vae_example/config.yaml | 2 +- .../ae_examples/fedprox_vae_example/server.py | 3 ++- examples/apfl_example/server.py | 8 +++++++- .../best_checkpoint_fczjmljm.pkl | Bin .../best_checkpoint_fdctxbts.pkl | Bin .../assets}/model_merge_example/0.pt | Bin .../assets}/model_merge_example/1.pt | Bin examples/basic_example/server.py | 1 + examples/ditto_example/server.py | 2 +- .../dp_fed_examples/client_level_dp/server.py | 1 + .../client_level_dp_weighted/server.py | 1 + .../instance_level_dp/server.py | 1 + examples/dp_scaffold_example/server.py | 1 + .../dynamic_layer_exchange_example/server.py | 3 ++- examples/ensemble_example/server.py | 1 + examples/feature_alignment_example/server.py | 1 + examples/fedbn_example/server.py | 2 +- examples/feddg_ga_example/server.py | 8 +++++++- examples/federated_eval_example/server.py | 1 + examples/fedopt_example/server.py | 3 ++- .../fedpca_examples/dim_reduction/client.py | 3 ++- .../fedpca_examples/dim_reduction/server.py | 1 + examples/fedper_example/server.py | 2 +- examples/fedpm_example/server.py | 1 + examples/fedprox_example/server.py | 4 +++- examples/fedrep_example/server.py | 2 +- .../fedsimclr_finetuning_example/client.py | 3 ++- .../fedsimclr_finetuning_example/server.py | 1 + .../fedsimclr_pretraining_example/client.py | 1 - .../fedsimclr_pretraining_example/server.py | 1 + examples/fenda_ditto_example/server.py | 2 +- examples/fenda_example/server.py | 2 +- examples/fl_plus_local_ft_example/server.py | 2 +- examples/flash_example/server.py | 2 +- examples/model_merge_example/README.md | 7 +++++-- examples/moon_example/server.py | 2 +- examples/mr_mtl_example/server.py | 2 +- examples/nnunet_example/server.py | 1 + examples/perfcl_example/server.py | 2 +- examples/scaffold_example/server.py | 1 + .../server.py | 3 ++- .../warm_up_example/fedavg_warm_up/README.md | 4 ---- .../warm_up_example/fedavg_warm_up/client.py | 16 ++++++++++++++-- .../fedavg_warm_up/config.yaml | 8 -------- .../warm_up_example/fedavg_warm_up/server.py | 18 +----------------- .../warmed_up_fedprox/README.md | 6 +----- .../warmed_up_fedprox/client.py | 13 +++++-------- .../warmed_up_fedprox/config.yaml | 8 -------- .../warmed_up_fedprox/server.py | 18 +----------------- .../warm_up_example/warmed_up_fenda/README.md | 2 +- .../warm_up_example/warmed_up_fenda/client.py | 12 +++++------- .../warmed_up_fenda/config.yaml | 8 -------- .../warm_up_example/warmed_up_fenda/server.py | 17 +---------------- fl4health/preprocessing/warmed_up_module.py | 2 +- fl4health/strategies/basic_fedavg.py | 4 ++++ fl4health/strategies/client_dp_fedavgm.py | 4 ++++ fl4health/strategies/scaffold.py | 2 ++ fl4health/utils/dataset.py | 2 +- tests/clients/test_evaluate_client.py | 1 - tests/losses/test_deep_mmd_loss.py | 2 -- .../test_layer_exchanger.py | 1 - .../smoke_tests/feature_alignment_config.yaml | 11 ----------- tests/smoke_tests/run_smoke_test.py | 4 ++-- tests/utils/sampler_test.py | 4 ---- 71 files changed, 114 insertions(+), 156 deletions(-) rename examples/assets/{ => fed_eval_example}/best_checkpoint_fczjmljm.pkl (100%) rename examples/assets/{ => fed_eval_example}/best_checkpoint_fdctxbts.pkl (100%) rename {assets/checkpoints_for_examples => examples/assets}/model_merge_example/0.pt (100%) rename {assets/checkpoints_for_examples => examples/assets}/model_merge_example/1.pt (100%) delete mode 100644 tests/smoke_tests/feature_alignment_config.yaml diff --git a/examples/ae_examples/cvae_dim_example/client.py b/examples/ae_examples/cvae_dim_example/client.py index a6d88071e..6b14545f5 100644 --- a/examples/ae_examples/cvae_dim_example/client.py +++ b/examples/ae_examples/cvae_dim_example/client.py @@ -15,7 +15,7 @@ from fl4health.clients.basic_client import BasicClient from fl4health.preprocessing.autoencoders.dim_reduction import CvaeFixedConditionProcessor from fl4health.utils.config import narrow_dict_type -from fl4health.utils.load_data import load_mnist_data +from fl4health.utils.load_data import ToNumpy, load_mnist_data from fl4health.utils.metrics import Accuracy, Metric from fl4health.utils.random import set_all_random_seeds from fl4health.utils.sampler import DirichletLabelBasedSampler @@ -30,7 +30,7 @@ def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) cvae_model_path = Path(narrow_dict_type(config, "cvae_model_path", str)) sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100) - transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(torch.flatten)]) + transform = transforms.Compose([ToNumpy(), transforms.ToTensor(), transforms.Lambda(torch.flatten)]) # CvaeFixedConditionProcessor is added to the data transform pipeline to encode the data samples train_loader, val_loader, _ = load_mnist_data( data_dir=self.data_path, diff --git a/examples/ae_examples/cvae_dim_example/server.py b/examples/ae_examples/cvae_dim_example/server.py index 8e41136cf..650d53d09 100644 --- a/examples/ae_examples/cvae_dim_example/server.py +++ b/examples/ae_examples/cvae_dim_example/server.py @@ -72,6 +72,7 @@ def main(config: dict[str, Any]) -> None: fl_config=config, strategy=strategy, checkpoint_and_state_module=checkpoint_and_state_module, + accept_failures=False, ) fl.server.start_server( diff --git a/examples/ae_examples/cvae_examples/conv_cvae_example/client.py b/examples/ae_examples/cvae_examples/conv_cvae_example/client.py index 9969f88d2..41e6cd347 100644 --- a/examples/ae_examples/cvae_examples/conv_cvae_example/client.py +++ b/examples/ae_examples/cvae_examples/conv_cvae_example/client.py @@ -17,7 +17,7 @@ from fl4health.preprocessing.autoencoders.loss import VaeLoss from fl4health.utils.config import narrow_dict_type from fl4health.utils.dataset_converter import AutoEncoderDatasetConverter -from fl4health.utils.load_data import load_mnist_data +from fl4health.utils.load_data import ToNumpy, load_mnist_data from fl4health.utils.metrics import Metric from fl4health.utils.random import set_all_random_seeds from fl4health.utils.sampler import DirichletLabelBasedSampler @@ -60,7 +60,7 @@ def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100) # To make sure pixels stay in the range [0.0, 1.0]. - transform = transforms.Compose([transforms.ToTensor()]) + transform = transforms.Compose([ToNumpy(), transforms.ToTensor()]) # To train an autoencoder-based model we need to set the data converter. train_loader, val_loader, _ = load_mnist_data( data_dir=self.data_path, diff --git a/examples/ae_examples/cvae_examples/conv_cvae_example/server.py b/examples/ae_examples/cvae_examples/conv_cvae_example/server.py index 0e6c5c16e..6c585b7b6 100644 --- a/examples/ae_examples/cvae_examples/conv_cvae_example/server.py +++ b/examples/ae_examples/cvae_examples/conv_cvae_example/server.py @@ -73,6 +73,7 @@ def main(config: dict[str, Any]) -> None: fl_config=config, strategy=strategy, checkpoint_and_state_module=checkpoint_and_state_module, + accept_failures=False, ) fl.server.start_server( diff --git a/examples/ae_examples/cvae_examples/mlp_cvae_example/client.py b/examples/ae_examples/cvae_examples/mlp_cvae_example/client.py index 23cf2ee40..545d17836 100644 --- a/examples/ae_examples/cvae_examples/mlp_cvae_example/client.py +++ b/examples/ae_examples/cvae_examples/mlp_cvae_example/client.py @@ -17,7 +17,7 @@ from fl4health.preprocessing.autoencoders.loss import VaeLoss from fl4health.utils.config import narrow_dict_type from fl4health.utils.dataset_converter import AutoEncoderDatasetConverter -from fl4health.utils.load_data import load_mnist_data +from fl4health.utils.load_data import ToNumpy, load_mnist_data from fl4health.utils.metrics import Metric from fl4health.utils.random import set_all_random_seeds from fl4health.utils.sampler import DirichletLabelBasedSampler @@ -50,7 +50,7 @@ def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: # ToTensor transform is used to make sure pixels stay in the range [0.0, 1.0]. # Flattening the image data to match the input shape of the model. flatten_transform = transforms.Lambda(lambda x: torch.flatten(x)) - transform = transforms.Compose([transforms.ToTensor(), flatten_transform]) + transform = transforms.Compose([ToNumpy(), transforms.ToTensor(), flatten_transform]) train_loader, val_loader, _ = load_mnist_data( data_dir=self.data_path, batch_size=batch_size, diff --git a/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py b/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py index 4389c73d1..ff66af3cb 100644 --- a/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py +++ b/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py @@ -73,6 +73,7 @@ def main(config: dict[str, Any]) -> None: fl_config=config, strategy=strategy, checkpoint_and_state_module=checkpoint_and_state_module, + accept_failures=False, ) fl.server.start_server( diff --git a/examples/ae_examples/fedprox_vae_example/client.py b/examples/ae_examples/fedprox_vae_example/client.py index 066031621..dee44a819 100644 --- a/examples/ae_examples/fedprox_vae_example/client.py +++ b/examples/ae_examples/fedprox_vae_example/client.py @@ -16,7 +16,7 @@ from fl4health.preprocessing.autoencoders.loss import VaeLoss from fl4health.utils.config import narrow_dict_type from fl4health.utils.dataset_converter import AutoEncoderDatasetConverter -from fl4health.utils.load_data import load_mnist_data +from fl4health.utils.load_data import ToNumpy, load_mnist_data from fl4health.utils.sampler import DirichletLabelBasedSampler @@ -25,7 +25,7 @@ def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100) # Flattening the input images to use an MLP-based variational autoencoder. - transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(torch.flatten)]) + transform = transforms.Compose([ToNumpy(), transforms.ToTensor(), transforms.Lambda(torch.flatten)]) # Create and pass the autoencoder data converter to the data loader. self.autoencoder_converter = AutoEncoderDatasetConverter(condition=None) train_loader, val_loader, _ = load_mnist_data( diff --git a/examples/ae_examples/fedprox_vae_example/config.yaml b/examples/ae_examples/fedprox_vae_example/config.yaml index eb9cdca93..3e687d3c3 100644 --- a/examples/ae_examples/fedprox_vae_example/config.yaml +++ b/examples/ae_examples/fedprox_vae_example/config.yaml @@ -8,7 +8,7 @@ batch_size: 32 # The batch size for client training # FedProx variables adaptive_proximal_weight: False # Whether to use adaptive proximal weight or not -proximal_weight : 0.1 # The proximal weight +initial_proximal_weight : 0.1 # The proximal weight # Checkpointing checkpoint_path: "examples/ae_examples/fedprox_vae_example" diff --git a/examples/ae_examples/fedprox_vae_example/server.py b/examples/ae_examples/fedprox_vae_example/server.py index ee5ca375b..a8ba5afff 100644 --- a/examples/ae_examples/fedprox_vae_example/server.py +++ b/examples/ae_examples/fedprox_vae_example/server.py @@ -66,7 +66,7 @@ def main(config: dict[str, Any]) -> None: fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, initial_parameters=get_all_model_parameters(model), - adapt_loss_weight=config["adapt_proximal_weight"], + adapt_loss_weight=config["adaptive_proximal_weight"], initial_loss_weight=config["initial_proximal_weight"], ) @@ -75,6 +75,7 @@ def main(config: dict[str, Any]) -> None: fl_config=config, strategy=strategy, checkpoint_and_state_module=checkpoint_and_state_module, + accept_failures=False, ) fl.server.start_server( diff --git a/examples/apfl_example/server.py b/examples/apfl_example/server.py index 2aa950a71..2b367eba2 100644 --- a/examples/apfl_example/server.py +++ b/examples/apfl_example/server.py @@ -60,7 +60,13 @@ def main(config: dict[str, Any]) -> None: ) client_manager = SimpleClientManager() - server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, reporters=[JsonReporter()]) + server = FlServer( + client_manager=client_manager, + fl_config=config, + strategy=strategy, + reporters=[JsonReporter()], + accept_failures=False, + ) fl.server.start_server( server=server, diff --git a/examples/assets/best_checkpoint_fczjmljm.pkl b/examples/assets/fed_eval_example/best_checkpoint_fczjmljm.pkl similarity index 100% rename from examples/assets/best_checkpoint_fczjmljm.pkl rename to examples/assets/fed_eval_example/best_checkpoint_fczjmljm.pkl diff --git a/examples/assets/best_checkpoint_fdctxbts.pkl b/examples/assets/fed_eval_example/best_checkpoint_fdctxbts.pkl similarity index 100% rename from examples/assets/best_checkpoint_fdctxbts.pkl rename to examples/assets/fed_eval_example/best_checkpoint_fdctxbts.pkl diff --git a/assets/checkpoints_for_examples/model_merge_example/0.pt b/examples/assets/model_merge_example/0.pt similarity index 100% rename from assets/checkpoints_for_examples/model_merge_example/0.pt rename to examples/assets/model_merge_example/0.pt diff --git a/assets/checkpoints_for_examples/model_merge_example/1.pt b/examples/assets/model_merge_example/1.pt similarity index 100% rename from assets/checkpoints_for_examples/model_merge_example/1.pt rename to examples/assets/model_merge_example/1.pt diff --git a/examples/basic_example/server.py b/examples/basic_example/server.py index 2472daed6..d1e383249 100644 --- a/examples/basic_example/server.py +++ b/examples/basic_example/server.py @@ -71,6 +71,7 @@ def main(config: dict[str, Any]) -> None: fl_config=config, strategy=strategy, checkpoint_and_state_module=checkpoint_and_state_module, + accept_failures=False, ) fl.server.start_server( diff --git a/examples/ditto_example/server.py b/examples/ditto_example/server.py index 0bf485ff3..6f0f43416 100644 --- a/examples/ditto_example/server.py +++ b/examples/ditto_example/server.py @@ -62,7 +62,7 @@ def main(config: dict[str, Any]) -> None: ) client_manager = SimpleClientManager() - server = DittoServer(client_manager=client_manager, fl_config=config, strategy=strategy) + server = DittoServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False) fl.server.start_server( server=server, diff --git a/examples/dp_fed_examples/client_level_dp/server.py b/examples/dp_fed_examples/client_level_dp/server.py index 3fb2ca448..ce6ad66b0 100644 --- a/examples/dp_fed_examples/client_level_dp/server.py +++ b/examples/dp_fed_examples/client_level_dp/server.py @@ -84,6 +84,7 @@ def main(config: dict[str, Any]) -> None: strategy=strategy, server_noise_multiplier=config["server_noise_multiplier"], num_server_rounds=config["n_server_rounds"], + accept_failures=False, ) fl.server.start_server( diff --git a/examples/dp_fed_examples/client_level_dp_weighted/server.py b/examples/dp_fed_examples/client_level_dp_weighted/server.py index 32236a635..fa22f81b8 100644 --- a/examples/dp_fed_examples/client_level_dp_weighted/server.py +++ b/examples/dp_fed_examples/client_level_dp_weighted/server.py @@ -80,6 +80,7 @@ def main(config: dict[str, Any]) -> None: clipping_noise_multiplier=config["clipping_bit_noise_multiplier"], beta=config["server_momentum"], weighted_aggregation=config["weighted_averaging"], + accept_failures=False, ) server = ClientLevelDPFedAvgServer( diff --git a/examples/dp_fed_examples/instance_level_dp/server.py b/examples/dp_fed_examples/instance_level_dp/server.py index a7068ffdb..13f97276f 100644 --- a/examples/dp_fed_examples/instance_level_dp/server.py +++ b/examples/dp_fed_examples/instance_level_dp/server.py @@ -106,6 +106,7 @@ def main(config: dict[str, Any]) -> None: batch_size=config["batch_size"], num_server_rounds=config["n_server_rounds"], checkpoint_and_state_module=checkpoint_and_state_module, + accept_failures=False, ) fl.server.start_server( diff --git a/examples/dp_scaffold_example/server.py b/examples/dp_scaffold_example/server.py index 95565832c..7737429b1 100644 --- a/examples/dp_scaffold_example/server.py +++ b/examples/dp_scaffold_example/server.py @@ -68,6 +68,7 @@ def main(config: dict[str, Any]) -> None: num_server_rounds=config["n_server_rounds"], strategy=strategy, warm_start=True, + accept_failures=False, ) fl.server.start_server( diff --git a/examples/dynamic_layer_exchange_example/server.py b/examples/dynamic_layer_exchange_example/server.py index eff4b2c1b..30eb30514 100644 --- a/examples/dynamic_layer_exchange_example/server.py +++ b/examples/dynamic_layer_exchange_example/server.py @@ -74,10 +74,11 @@ def main(config: dict[str, Any]) -> None: fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, initial_parameters=get_all_model_parameters(model), + accept_failures=False, ) client_manager = SimpleClientManager() - server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy) + server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False) fl.server.start_server( server=server, diff --git a/examples/ensemble_example/server.py b/examples/ensemble_example/server.py index c3a8c623a..c52da9b8f 100644 --- a/examples/ensemble_example/server.py +++ b/examples/ensemble_example/server.py @@ -59,6 +59,7 @@ def main(config: dict[str, Any]) -> None: fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, initial_parameters=get_all_model_parameters(initial_model), + accept_failures=False, ) fl.server.start_server( diff --git a/examples/feature_alignment_example/server.py b/examples/feature_alignment_example/server.py index 6756bcbe0..fd45f14ca 100644 --- a/examples/feature_alignment_example/server.py +++ b/examples/feature_alignment_example/server.py @@ -62,6 +62,7 @@ def main(config: dict[str, Any]) -> None: initialize_parameters=get_initial_model_parameters, strategy=strategy, tabular_features_source_of_truth=tab_feature_info_encoder_hospital1, + accept_failures=False, ) fl.server.start_server( diff --git a/examples/fedbn_example/server.py b/examples/fedbn_example/server.py index 46a9aad01..efea03913 100644 --- a/examples/fedbn_example/server.py +++ b/examples/fedbn_example/server.py @@ -70,7 +70,7 @@ def main(config: dict[str, Any], server_address: str, dataset_name: str) -> None ) client_manager = SimpleClientManager() - server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy) + server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False) fl.server.start_server( server=server, diff --git a/examples/feddg_ga_example/server.py b/examples/feddg_ga_example/server.py index 7d7c2775b..5abb2b6d5 100644 --- a/examples/feddg_ga_example/server.py +++ b/examples/feddg_ga_example/server.py @@ -70,7 +70,13 @@ def main(config: dict[str, Any]) -> None: # will return the same sampling until it is told to reset, which in FedDgGaStrategy # is done right before fit_round. client_manager = FixedSamplingClientManager() - server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, reporters=[JsonReporter()]) + server = FlServer( + client_manager=client_manager, + fl_config=config, + strategy=strategy, + reporters=[JsonReporter()], + accept_failures=False, + ) fl.server.start_server( server=server, diff --git a/examples/federated_eval_example/server.py b/examples/federated_eval_example/server.py index 9423513dc..20b75c060 100644 --- a/examples/federated_eval_example/server.py +++ b/examples/federated_eval_example/server.py @@ -23,6 +23,7 @@ def main(config: dict[str, Any], server_checkpoint_path: Path | None) -> None: evaluate_config=evaluate_config, evaluate_metrics_aggregation_fn=uniform_evaluate_metrics_aggregation_fn, min_available_clients=config["n_clients"], + accept_failures=False, ) fl.server.start_server( diff --git a/examples/fedopt_example/server.py b/examples/fedopt_example/server.py index ba91b197f..8d0f75d2e 100644 --- a/examples/fedopt_example/server.py +++ b/examples/fedopt_example/server.py @@ -147,10 +147,11 @@ def main(config: dict[str, Any]) -> None: on_evaluate_config_fn=fit_config_fn, # Server side weight initialization initial_parameters=get_all_model_parameters(initial_model), + accept_failures=False, ) client_manager = SimpleClientManager() - server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy) + server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False) fl.server.start_server( server_address=config["server_address"], diff --git a/examples/fedpca_examples/dim_reduction/client.py b/examples/fedpca_examples/dim_reduction/client.py index 09c2ce00d..11c44928d 100644 --- a/examples/fedpca_examples/dim_reduction/client.py +++ b/examples/fedpca_examples/dim_reduction/client.py @@ -14,7 +14,7 @@ from fl4health.clients.basic_client import BasicClient from fl4health.preprocessing.pca_preprocessor import PcaPreprocessor from fl4health.utils.config import narrow_dict_type -from fl4health.utils.load_data import get_train_and_val_mnist_datasets +from fl4health.utils.load_data import ToNumpy, get_train_and_val_mnist_datasets from fl4health.utils.metrics import Accuracy from fl4health.utils.random import set_all_random_seeds from fl4health.utils.sampler import DirichletLabelBasedSampler @@ -31,6 +31,7 @@ def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: # Get training and validation datasets. transform = transforms.Compose( [ + ToNumpy(), transforms.ToTensor(), transforms.Normalize((0.5), (0.5)), ] diff --git a/examples/fedpca_examples/dim_reduction/server.py b/examples/fedpca_examples/dim_reduction/server.py index 10bfc5e4b..7e43bb3df 100644 --- a/examples/fedpca_examples/dim_reduction/server.py +++ b/examples/fedpca_examples/dim_reduction/server.py @@ -72,6 +72,7 @@ def main(config: dict[str, Any]) -> None: fl_config=config, strategy=strategy, checkpoint_and_state_module=checkpoint_and_state_module, + accept_failures=False, ) fl.server.start_server( diff --git a/examples/fedper_example/server.py b/examples/fedper_example/server.py index 8a952055a..c073c4804 100644 --- a/examples/fedper_example/server.py +++ b/examples/fedper_example/server.py @@ -68,7 +68,7 @@ def main(config: dict[str, Any]) -> None: ) client_manager = SimpleClientManager() - server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy) + server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False) fl.server.start_server( server=server, diff --git a/examples/fedpm_example/server.py b/examples/fedpm_example/server.py index dd3629519..fcedd729f 100644 --- a/examples/fedpm_example/server.py +++ b/examples/fedpm_example/server.py @@ -65,6 +65,7 @@ def main(config: dict[str, Any]) -> None: initial_parameters=get_all_model_parameters(initial_model), # Perform Bayesian aggregation. bayesian_aggregation=True, + accept_failures=False, ) client_manager = SimpleClientManager() diff --git a/examples/fedprox_example/server.py b/examples/fedprox_example/server.py index 274a6ffe5..aa676691b 100644 --- a/examples/fedprox_example/server.py +++ b/examples/fedprox_example/server.py @@ -81,7 +81,9 @@ def main(config: dict[str, Any], server_address: str) -> None: reporters = [wandb_reporter, json_reporter] else: reporters = [json_reporter] - server = FedProxServer(client_manager=client_manager, fl_config=config, strategy=strategy, reporters=reporters) + server = FedProxServer( + client_manager=client_manager, fl_config=config, strategy=strategy, reporters=reporters, accept_failures=False + ) fl.server.start_server( server=server, diff --git a/examples/fedrep_example/server.py b/examples/fedrep_example/server.py index be2754782..3c4179f32 100644 --- a/examples/fedrep_example/server.py +++ b/examples/fedrep_example/server.py @@ -70,7 +70,7 @@ def main(config: dict[str, Any]) -> None: ) client_manager = SimpleClientManager() - server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy) + server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False) fl.server.start_server( server=server, diff --git a/examples/fedsimclr_example/fedsimclr_finetuning_example/client.py b/examples/fedsimclr_example/fedsimclr_finetuning_example/client.py index fb61a40ed..48808ff67 100644 --- a/examples/fedsimclr_example/fedsimclr_finetuning_example/client.py +++ b/examples/fedsimclr_example/fedsimclr_finetuning_example/client.py @@ -15,7 +15,7 @@ from fl4health.model_bases.fedsimclr_base import FedSimClrModel from fl4health.utils.config import narrow_dict_type from fl4health.utils.dataset import TensorDataset -from fl4health.utils.load_data import get_cifar10_data_and_target_tensors, split_data_and_targets +from fl4health.utils.load_data import ToNumpy, get_cifar10_data_and_target_tensors, split_data_and_targets from fl4health.utils.metrics import Accuracy @@ -26,6 +26,7 @@ def get_finetune_dataset(data_dir: Path, batch_size: int) -> tuple[DataLoader, D input_transform = transforms.Compose( [ + ToNumpy(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ] diff --git a/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py b/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py index f87060b61..13a577d20 100644 --- a/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py +++ b/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py @@ -76,6 +76,7 @@ def main(config: dict[str, Any]) -> None: fl_config=config, strategy=strategy, checkpoint_and_state_module=checkpoint_and_state_module, + accept_failures=False, ) fl.server.start_server( diff --git a/examples/fedsimclr_example/fedsimclr_pretraining_example/client.py b/examples/fedsimclr_example/fedsimclr_pretraining_example/client.py index efa0bf5fa..d981262f6 100644 --- a/examples/fedsimclr_example/fedsimclr_pretraining_example/client.py +++ b/examples/fedsimclr_example/fedsimclr_pretraining_example/client.py @@ -34,7 +34,6 @@ def get_transforms() -> tuple[Callable, Callable]: target_transform = transforms.Compose( [ - ToNumpy(), transforms.ToPILImage(), transforms.RandomHorizontalFlip(p=0.5), transforms.RandomApply([color_jitter], p=0.8), diff --git a/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py b/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py index 0821936fa..1292e7462 100644 --- a/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py +++ b/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py @@ -75,6 +75,7 @@ def main(config: dict[str, Any]) -> None: fl_config=config, strategy=strategy, checkpoint_and_state_module=checkpoint_and_state_module, + accept_failures=False, ) fl.server.start_server( diff --git a/examples/fenda_ditto_example/server.py b/examples/fenda_ditto_example/server.py index f68871f24..ceeec21b9 100644 --- a/examples/fenda_ditto_example/server.py +++ b/examples/fenda_ditto_example/server.py @@ -68,7 +68,7 @@ def main(config: dict[str, Any]) -> None: ) client_manager = SimpleClientManager() - server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy) + server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False) fl.server.start_server( server=server, diff --git a/examples/fenda_example/server.py b/examples/fenda_example/server.py index 64c837bd3..42ec547a6 100644 --- a/examples/fenda_example/server.py +++ b/examples/fenda_example/server.py @@ -64,7 +64,7 @@ def main(config: dict[str, Any]) -> None: ) client_manager = SimpleClientManager() - server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy) + server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False) fl.server.start_server( server=server, diff --git a/examples/fl_plus_local_ft_example/server.py b/examples/fl_plus_local_ft_example/server.py index dd13a0d64..bb3f081d5 100644 --- a/examples/fl_plus_local_ft_example/server.py +++ b/examples/fl_plus_local_ft_example/server.py @@ -54,7 +54,7 @@ def main(config: dict[str, Any]) -> None: ) client_manager = SimpleClientManager() - server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy) + server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False) fl.server.start_server( server_address="0.0.0.0:8080", diff --git a/examples/flash_example/server.py b/examples/flash_example/server.py index 535c73bb1..d4889d1e8 100644 --- a/examples/flash_example/server.py +++ b/examples/flash_example/server.py @@ -45,7 +45,7 @@ def main(config: dict[str, Any]) -> None: ) client_manager = SimpleClientManager() - server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy) + server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False) fl.server.start_server( server=server, diff --git a/examples/model_merge_example/README.md b/examples/model_merge_example/README.md index 0814c5746..684663c28 100644 --- a/examples/model_merge_example/README.md +++ b/examples/model_merge_example/README.md @@ -4,7 +4,7 @@ a copy of the same architecture with different weights initialized via local pre average these weights and perform evaluation on the client side and the server side with the provided evaluation function. The server expects two clients to be spun up (i.e. it will wait until two clients report in before starting model merging and evaluation). For convenience, pre-trained models on the MNIST -train set have been provided for each of the clients in `assets/checkpoints_for_examples/model_merge_example` +train set have been provided for each of the clients in `/examples/assets/model_merge_example/` under `0.pt` and `1.pt`. The model merging and subsequent evaluation can be performed with these weights out-of-the-box. @@ -15,8 +15,11 @@ In order to run the example, first ensure you have [installed the dependencies i The next step is to start the server by running: ``` -python -m examples.model_merge_example.server +python -m examples.model_merge_example.server --config_path /path/to/config ``` +Optionally, you can provide a path to an evaluation dataset (`--data_path`) to evaluate the merged models on the +server side. + For a full list of arguments and their definitions: `python -m examples.model_merge_example.server --help` ## Starting Clients diff --git a/examples/moon_example/server.py b/examples/moon_example/server.py index 21437650f..fb322027f 100644 --- a/examples/moon_example/server.py +++ b/examples/moon_example/server.py @@ -60,7 +60,7 @@ def main(config: dict[str, Any]) -> None: initial_parameters=get_all_model_parameters(initial_model), ) client_manager = SimpleClientManager() - server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy) + server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False) fl.server.start_server( server=server, diff --git a/examples/mr_mtl_example/server.py b/examples/mr_mtl_example/server.py index c6933bc66..524cc48e6 100644 --- a/examples/mr_mtl_example/server.py +++ b/examples/mr_mtl_example/server.py @@ -63,7 +63,7 @@ def main(config: dict[str, Any]) -> None: ) client_manager = SimpleClientManager() - server = MrMtlServer(client_manager=client_manager, fl_config=config, strategy=strategy) + server = MrMtlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False) fl.server.start_server( server=server, diff --git a/examples/nnunet_example/server.py b/examples/nnunet_example/server.py index 0b38f7fa2..f80cfb127 100644 --- a/examples/nnunet_example/server.py +++ b/examples/nnunet_example/server.py @@ -109,6 +109,7 @@ def main( strategy=strategy, checkpoint_and_state_module=checkpoint_and_state_module, server_name=server_name, + accept_failures=False, ) fl.server.start_server( diff --git a/examples/perfcl_example/server.py b/examples/perfcl_example/server.py index 563c11332..752b8f6d8 100644 --- a/examples/perfcl_example/server.py +++ b/examples/perfcl_example/server.py @@ -64,7 +64,7 @@ def main(config: dict[str, Any]) -> None: ) client_manager = SimpleClientManager() - server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy) + server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False) fl.server.start_server( server=server, diff --git a/examples/scaffold_example/server.py b/examples/scaffold_example/server.py index 3512df2df..3debaa60a 100644 --- a/examples/scaffold_example/server.py +++ b/examples/scaffold_example/server.py @@ -58,6 +58,7 @@ def main(config: dict[str, Any]) -> None: strategy=strategy, warm_start=True, reporters=[JsonReporter()], + accept_failures=False, ) fl.server.start_server( server=server, diff --git a/examples/sparse_tensor_partial_exchange_example/server.py b/examples/sparse_tensor_partial_exchange_example/server.py index e41c6948b..3f1f9e4ea 100644 --- a/examples/sparse_tensor_partial_exchange_example/server.py +++ b/examples/sparse_tensor_partial_exchange_example/server.py @@ -56,10 +56,11 @@ def main(config: dict[str, Any]) -> None: fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, initial_parameters=get_all_model_parameters(model), + accept_failures=False, ) client_manager = SimpleClientManager() - server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy) + server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False) fl.server.start_server( server=server, diff --git a/examples/warm_up_example/fedavg_warm_up/README.md b/examples/warm_up_example/fedavg_warm_up/README.md index ce1b387a4..8516981b3 100644 --- a/examples/warm_up_example/fedavg_warm_up/README.md +++ b/examples/warm_up_example/fedavg_warm_up/README.md @@ -6,10 +6,6 @@ The server has some custom metrics aggregation and uses Federated Averaging as i As this is a warm-up training for consecutive runs with different Federated Learning (FL) algorithms, it is crucial to set a fixed seed for both clients and the server to ensure uniformity in random data points across these runs. Therefore, we make sure to set a fixed seed for these consecutive runs in both the `client.py` and `server.py` files. Additionally, it is important to establish a checkpointing strategy for the clients using their randomly generated unique client names. This allows us to load each client's warmed-up model from this example in further instances. In this particular scenario, we set the checkpointing strategy to save the latest model. This ensures that we can load the trained local model for each client from this example in subsequent runs as a warmed-up model. -### Weights and Biases Reporting - -This example is also capable of logging results to your Weights and Biases account by setting `enabled` to `True` in the `config.yaml` under the `reporting_config` section. You'll also need to set the `entity` value to your Weights and Biases entity. Once those two things are set, you should be able to run the example and log the results to W and B directly. - ### Running the Example In order to run the example, first ensure you have [installed the dependencies in your virtual environment according to the main README](/README.md#development-requirements) and it has been activated. diff --git a/examples/warm_up_example/fedavg_warm_up/client.py b/examples/warm_up_example/fedavg_warm_up/client.py index 3948a5398..0eb3aba1e 100644 --- a/examples/warm_up_example/fedavg_warm_up/client.py +++ b/examples/warm_up_example/fedavg_warm_up/client.py @@ -30,9 +30,11 @@ def __init__( metrics: Sequence[Metric], device: torch.device, checkpoint_dir: str, + client_name: str, ) -> None: + # Checkpointing is crucial for the warm up process - checkpoint_name = f"client_{self.client_name}_latest_model.pkl" + checkpoint_name = f"client_{client_name}_latest_model.pkl" post_aggregation_checkpointer = LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) checkpoint_and_state_module = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) @@ -41,6 +43,7 @@ def __init__( metrics=metrics, device=device, checkpoint_and_state_module=checkpoint_and_state_module, + client_name=client_name, ) def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: @@ -83,6 +86,13 @@ def get_criterion(self, config: Config) -> _Loss: help="Path to the directory where the checkpoints are stored", required=True, ) + parser.add_argument( + "--client_name", + action="store", + type=str, + help="Name for the client, this will also be used to set the checkpoint name", + required=True, + ) args = parser.parse_args() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -94,7 +104,9 @@ def get_criterion(self, config: Config) -> _Loss: set_all_random_seeds(args.seed) # Start the client - client = MnistFedAvgClient(data_path, [Accuracy()], device, checkpoint_dir=args.checkpoint_dir) + client = MnistFedAvgClient( + data_path, [Accuracy()], device, checkpoint_dir=args.checkpoint_dir, client_name=args.client_name + ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) # Shutdown the client gracefully diff --git a/examples/warm_up_example/fedavg_warm_up/config.yaml b/examples/warm_up_example/fedavg_warm_up/config.yaml index 057c0e29a..4ad40dfb4 100644 --- a/examples/warm_up_example/fedavg_warm_up/config.yaml +++ b/examples/warm_up_example/fedavg_warm_up/config.yaml @@ -5,11 +5,3 @@ n_server_rounds: 2 # The number of rounds to run FL n_clients: 3 # The number of clients in the FL experiment local_epochs: 1 # The number of epochs to complete for client batch_size: 128 # The batch size for client training - -reporting_config: - project: FL4Health # Name of the project under which everything should be logged - name: "FedAvg Server" # Name of the run on the server-side, each client will also have it's own run name - group: "FedAvg Experiment" # Group under which each of the FL run logging will be stored - entity: "your_entity_here" # WandB user name - notes: "Testing WB reporting" - tags: ["Test", "FedAvg"] diff --git a/examples/warm_up_example/fedavg_warm_up/server.py b/examples/warm_up_example/fedavg_warm_up/server.py index 38574b617..f7e41ee80 100644 --- a/examples/warm_up_example/fedavg_warm_up/server.py +++ b/examples/warm_up_example/fedavg_warm_up/server.py @@ -10,7 +10,6 @@ from examples.models.cnn_model import MnistNet from examples.utils.functions import make_dict_with_epochs_or_steps -from fl4health.reporting import WandBReporter from fl4health.servers.base_server import FlServer from fl4health.strategies.basic_fedavg import BasicFedAvg from fl4health.utils.config import load_config @@ -22,9 +21,6 @@ def fit_config( batch_size: int, n_server_rounds: int, - project: str, - group: str, - entity: str, current_round: int, local_epochs: int | None = None, local_steps: int | None = None, @@ -34,9 +30,6 @@ def fit_config( "batch_size": batch_size, "n_server_rounds": n_server_rounds, "current_server_round": current_round, - "project": project, - "group": group, - "entity": entity, } @@ -46,10 +39,6 @@ def main(config: dict[str, Any], server_address: str) -> None: fit_config, config["batch_size"], config["n_server_rounds"], - # NOTE: that name is not included, it will be set in the clients - config["reporting_config"].get("project", ""), - config["reporting_config"].get("group", ""), - config["reporting_config"].get("entity", ""), local_epochs=config.get("local_epochs"), local_steps=config.get("local_steps"), ) @@ -71,13 +60,8 @@ def main(config: dict[str, Any], server_address: str) -> None: ) client_manager = SimpleClientManager() - if "reporting_config" in config: - wandb_reporter = WandBReporter("round", **config["reporting_config"]) - reporters = [wandb_reporter] - else: - reporters = [] - server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, reporters=reporters) + server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False) fl.server.start_server( server=server, diff --git a/examples/warm_up_example/warmed_up_fedprox/README.md b/examples/warm_up_example/warmed_up_fedprox/README.md index a90efb8fc..8a7f740de 100644 --- a/examples/warm_up_example/warmed_up_fedprox/README.md +++ b/examples/warm_up_example/warmed_up_fedprox/README.md @@ -6,10 +6,6 @@ The server has some custom metrics aggregation and uses FedProx as its server-si After the warm-up training, clients can load their warmed-up models and continue training with the FedProx algorithm. To maintain consistency in the data loader between both runs, it is crucial to set a fixed seed for both clients and the server, ensuring uniformity in random data points across consecutive runs. Therefore, we ensure a fixed seed is set for these consecutive runs in both the `client.py` and `server.py` files. Additionally, to load the warmed-up models, it's important provide the path to the pretrained models based on client's unique name, ensuring that we can load the trained local model for each client from the previous example as a warmed-up model. Since models in the two runs can be different, loading weights from the pretrained model requires providing a mapping between the pretrained model and the model used in FL training. This mapping is accomplished through the `weights_mapping.json` file, which contains the names of the pretrained model's layers and the corresponding names of the layers in the model used in FL training. -### Weights and Biases Reporting - -This example is also capable of logging results to your Weights and Biases account by setting `enabled` to `True` in the `config.yaml` under the `reporting_config` section. You'll also need to set the `entity` value to your Weights and Biases entity. Once those two things are set, you should be able to run the example and log the results to W and B directly. - ### Running the Example In order to run the example, first ensure you have [installed the dependencies in your virtual environment according to the main README](/README.md#development-requirements) and it has been activated. @@ -31,7 +27,7 @@ from the FL4Health directory. The following arguments must be present in the spe Once the server has started and logged "FL starting," the next step, in separate terminals, is to start the three clients. This is done by simply running (remembering to activate your environment) ``` -python -m examples.warm_up_example.warmed_up_fedprox.client --dataset_path /path/to/data --seed "SEED" --pretrained_model_dir /path/to/checkpointing/directory --weights_mapping_path /path/to/weights/mapping/file +python -m examples.warm_up_example.warmed_up_fedprox.client --dataset_path /path/to/data --seed "SEED" --pretrained_model_path /path/to/model_checkpoint --weights_mapping_path /path/to/weights/mapping/file ``` **NOTE**: The argument `dataset_path` has two functions, depending on whether the dataset exists locally or not. If the dataset already exists at the path specified, it will be loaded from there. Otherwise, the dataset will be automatically downloaded to the path specified and used in the run. diff --git a/examples/warm_up_example/warmed_up_fedprox/client.py b/examples/warm_up_example/warmed_up_fedprox/client.py index 1dddf1e00..1a77fc548 100644 --- a/examples/warm_up_example/warmed_up_fedprox/client.py +++ b/examples/warm_up_example/warmed_up_fedprox/client.py @@ -1,5 +1,4 @@ import argparse -import os from collections.abc import Sequence from logging import INFO from pathlib import Path @@ -29,7 +28,7 @@ def __init__( data_path: Path, metrics: Sequence[Metric], device: torch.device, - pretrained_model_dir: Path, + pretrained_model_path: Path, weights_mapping_path: Path | None, ) -> None: super().__init__( @@ -37,11 +36,9 @@ def __init__( metrics=metrics, device=device, ) - # Load the warmed up module - pretrained_model_name = f"client_{self.client_name}_latest_model.pkl" self.warmed_up_module = WarmedUpModule( - pretrained_model_path=Path(os.path.join(pretrained_model_dir, pretrained_model_name)), + pretrained_model_path=pretrained_model_path, weights_mapping_path=weights_mapping_path, ) @@ -85,7 +82,7 @@ def initialize_all_model_weights(self, parameters: NDArrays, config: Config) -> required=False, ) parser.add_argument( - "--pretrained_model_dir", + "--pretrained_model_path", action="store", type=str, help="Path to the pretrained model", @@ -102,7 +99,7 @@ def initialize_all_model_weights(self, parameters: NDArrays, config: Config) -> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - pretrained_model_dir = Path(args.pretrained_model_dir) + pretrained_model_path = Path(args.pretrained_model_path) weights_mapping_path = Path(args.weights_mapping_path) if args.weights_mapping_path else None log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") @@ -115,7 +112,7 @@ def initialize_all_model_weights(self, parameters: NDArrays, config: Config) -> data_path, [Accuracy()], device, - pretrained_model_dir, + pretrained_model_path, weights_mapping_path, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/examples/warm_up_example/warmed_up_fedprox/config.yaml b/examples/warm_up_example/warmed_up_fedprox/config.yaml index 420590b72..c15b6700d 100644 --- a/examples/warm_up_example/warmed_up_fedprox/config.yaml +++ b/examples/warm_up_example/warmed_up_fedprox/config.yaml @@ -14,11 +14,3 @@ proximal_weight_patience : 5 # The number of rounds to wait before increasing or n_clients: 3 # The number of clients in the FL experiment local_epochs: 1 # The number of epochs to complete for client batch_size: 128 # The batch size for client training - -reporting_config: - project: FL4Health # Name of the project under which everything should be logged - name: "FedProx Server" # Name of the run on the server-side, each client will also have it's own run name - group: "FedProx Experiment" # Group under which each of the FL run logging will be stored - entity: "your_entity_here" # WandB user name - notes: "Testing WB reporting" - tags: ["Test", "FedProx"] diff --git a/examples/warm_up_example/warmed_up_fedprox/server.py b/examples/warm_up_example/warmed_up_fedprox/server.py index 98e3c56f1..133cf3e91 100644 --- a/examples/warm_up_example/warmed_up_fedprox/server.py +++ b/examples/warm_up_example/warmed_up_fedprox/server.py @@ -10,7 +10,6 @@ from examples.models.cnn_model import MnistNet from examples.utils.functions import make_dict_with_epochs_or_steps -from fl4health.reporting import WandBReporter from fl4health.servers.base_server import FlServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint from fl4health.utils.config import load_config @@ -22,9 +21,6 @@ def fit_config( batch_size: int, n_server_rounds: int, - project: str, - group: str, - entity: str, current_round: int, local_epochs: int | None = None, local_steps: int | None = None, @@ -34,10 +30,6 @@ def fit_config( "batch_size": batch_size, "n_server_rounds": n_server_rounds, "current_server_round": current_round, - "project": project, - "group": group, - "entity": entity, - "entity": entity, } @@ -48,9 +40,6 @@ def main(config: dict[str, Any], server_address: str) -> None: config["batch_size"], config["n_server_rounds"], # NOTE: that name is not included, it will be set in the clients - config["reporting_config"].get("project", ""), - config["reporting_config"].get("group", ""), - config["reporting_config"].get("entity", ""), local_epochs=config.get("local_epochs"), local_steps=config.get("local_steps"), ) @@ -76,12 +65,7 @@ def main(config: dict[str, Any], server_address: str) -> None: ) client_manager = SimpleClientManager() - if "reporting_config" in config: - wandb_reporter = WandBReporter("round", **config["reporting_config"]) - reporters = [wandb_reporter] - else: - reporters = [] - server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, reporters=reporters) + server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False) fl.server.start_server( server=server, diff --git a/examples/warm_up_example/warmed_up_fenda/README.md b/examples/warm_up_example/warmed_up_fenda/README.md index 2c61dfece..c643c0989 100644 --- a/examples/warm_up_example/warmed_up_fenda/README.md +++ b/examples/warm_up_example/warmed_up_fenda/README.md @@ -25,7 +25,7 @@ from the FL4Health directory. The following arguments must be present in the spe Once the server has started and logged "FL starting," the next step, in separate terminals, is to start the three clients. This is done by simply running (remembering to activate your environment) ``` -python -m examples.warm_up_example.warmed_up_fenda.client --dataset_path /path/to/data --seed "SEED" --pretrained_model_dir /path/to/checkpointing/directory --weights_mapping_path /path/to/weights/mapping/file +python -m examples.warm_up_example.warmed_up_fenda.client --dataset_path /path/to/data --seed "SEED" --pretrained_model_path /path/to/model_checkpoint --weights_mapping_path /path/to/weights/mapping/file ``` **NOTE**: The argument `dataset_path` has two functions, depending on whether the dataset exists locally or not. If the dataset already exists at the path specified, it will be loaded from there. Otherwise, the dataset will be automatically downloaded to the path specified and used in the run. diff --git a/examples/warm_up_example/warmed_up_fenda/client.py b/examples/warm_up_example/warmed_up_fenda/client.py index b57974e9b..869545ad9 100644 --- a/examples/warm_up_example/warmed_up_fenda/client.py +++ b/examples/warm_up_example/warmed_up_fenda/client.py @@ -1,5 +1,4 @@ import argparse -import os from collections.abc import Sequence from logging import INFO from pathlib import Path @@ -31,7 +30,7 @@ def __init__( data_path: Path, metrics: Sequence[Metric], device: torch.device, - pretrained_model_dir: Path, + pretrained_model_path: Path, weights_mapping_path: Path | None, ) -> None: super().__init__( @@ -41,9 +40,8 @@ def __init__( ) # Load the warmed up module - pretrained_model_name = f"client_{self.client_name}_latest_model.pkl" self.warmed_up_module = WarmedUpModule( - pretrained_model_path=Path(os.path.join(pretrained_model_dir, pretrained_model_name)), + pretrained_model_path=pretrained_model_path, weights_mapping_path=weights_mapping_path, ) @@ -88,7 +86,7 @@ def initialize_all_model_weights(self, parameters: NDArrays, config: Config) -> required=False, ) parser.add_argument( - "--pretrained_model_dir", + "--pretrained_model_path", action="store", type=str, help="Path to the pretrained model", @@ -105,7 +103,7 @@ def initialize_all_model_weights(self, parameters: NDArrays, config: Config) -> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_path = Path(args.dataset_path) - pretrained_model_dir = Path(args.pretrained_model_dir) + pretrained_model_path = Path(args.pretrained_model_path) weights_mapping_path = Path(args.weights_mapping_path) if args.weights_mapping_path else None log(INFO, f"Device to be used: {device}") log(INFO, f"Server Address: {args.server_address}") @@ -118,7 +116,7 @@ def initialize_all_model_weights(self, parameters: NDArrays, config: Config) -> data_path, [Accuracy("accuracy")], device, - pretrained_model_dir, + pretrained_model_path, weights_mapping_path, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/examples/warm_up_example/warmed_up_fenda/config.yaml b/examples/warm_up_example/warmed_up_fenda/config.yaml index 2a2162164..3e218515f 100644 --- a/examples/warm_up_example/warmed_up_fenda/config.yaml +++ b/examples/warm_up_example/warmed_up_fenda/config.yaml @@ -5,11 +5,3 @@ n_server_rounds: 3 # The number of rounds to run FL n_clients: 3 # The number of clients in the FL experiment local_epochs: 1 # The number of epochs to complete for client batch_size: 128 # The batch size for client training - -reporting_config: - project: FL4Health # Name of the project under which everything should be logged - name: "Fenda Server" # Name of the run on the server-side, each client will also have it's own run name - group: "Fenda Experiment" # Group under which each of the FL run logging will be stored - entity: "your_entity_here" # WandB user name - notes: "Testing WB reporting" - tags: ["Test", "Fenda"] diff --git a/examples/warm_up_example/warmed_up_fenda/server.py b/examples/warm_up_example/warmed_up_fenda/server.py index a622175ce..8d1bb8469 100644 --- a/examples/warm_up_example/warmed_up_fenda/server.py +++ b/examples/warm_up_example/warmed_up_fenda/server.py @@ -13,7 +13,6 @@ from examples.utils.functions import make_dict_with_epochs_or_steps from fl4health.model_bases.fenda_base import FendaModel from fl4health.model_bases.parallel_split_models import ParallelFeatureJoinMode -from fl4health.reporting import WandBReporter from fl4health.servers.base_server import FlServer from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn @@ -24,9 +23,6 @@ def fit_config( batch_size: int, n_server_rounds: int, - project: str, - group: str, - entity: str, current_round: int, local_epochs: int | None = None, local_steps: int | None = None, @@ -36,9 +32,6 @@ def fit_config( "batch_size": batch_size, "n_server_rounds": n_server_rounds, "current_server_round": current_round, - "project": project, - "group": group, - "entity": entity, } @@ -49,9 +42,6 @@ def main(config: dict[str, Any], server_address: str) -> None: config["batch_size"], config["n_server_rounds"], # NOTE: that name is not included, it will be set in the clients - config["reporting_config"].get("project", ""), - config["reporting_config"].get("group", ""), - config["reporting_config"].get("entity", ""), local_epochs=config.get("local_epochs"), local_steps=config.get("local_steps"), ) @@ -75,12 +65,7 @@ def main(config: dict[str, Any], server_address: str) -> None: ) client_manager = SimpleClientManager() - if "reporting_config" in config: - wandb_reporter = WandBReporter("round", **config["reporting_config"]) - reporters = [wandb_reporter] - else: - reporters = [] - server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, reporters=reporters) + server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy, accept_failures=False) fl.server.start_server( server=server, diff --git a/fl4health/preprocessing/warmed_up_module.py b/fl4health/preprocessing/warmed_up_module.py index 4b33254da..e8e498181 100644 --- a/fl4health/preprocessing/warmed_up_module.py +++ b/fl4health/preprocessing/warmed_up_module.py @@ -118,7 +118,7 @@ def load_from_pretrained(self, model: torch.nn.Module) -> torch.nn.Module: f"state won't be loaded. Key {pretrained_key} not found in the pretrained model states.", ) - log(INFO, f"{len(matching_state)}/{len(target_model_state)} states got matched.") + log(INFO, f"{len(matching_state)}/{len(target_model_state)} states were matched.") target_model_state.update(matching_state) model.load_state_dict(target_model_state) diff --git a/fl4health/strategies/basic_fedavg.py b/fl4health/strategies/basic_fedavg.py index 7317fe34d..60c480be5 100644 --- a/fl4health/strategies/basic_fedavg.py +++ b/fl4health/strategies/basic_fedavg.py @@ -132,6 +132,8 @@ def configure_fit( if self.on_fit_config_fn is not None: # Custom fit config function provided config = self.on_fit_config_fn(server_round) + else: + config = {"current_server_round": server_round} fit_ins = FitIns(parameters, config) # Sample clients @@ -174,6 +176,8 @@ def configure_evaluate( if self.on_evaluate_config_fn is not None: # Custom evaluation config function provided config = self.on_evaluate_config_fn(server_round) + else: + config = {"current_server_round": server_round} evaluate_ins = EvaluateIns(parameters, config) # Sample clients diff --git a/fl4health/strategies/client_dp_fedavgm.py b/fl4health/strategies/client_dp_fedavgm.py index 5c3428865..ef6d9d31b 100644 --- a/fl4health/strategies/client_dp_fedavgm.py +++ b/fl4health/strategies/client_dp_fedavgm.py @@ -384,6 +384,8 @@ def configure_fit( if self.on_fit_config_fn is not None: # Custom fit config function provided config = self.on_fit_config_fn(server_round) + else: + config = {"current_server_round": server_round} fit_ins = FitIns(parameters, config) @@ -426,6 +428,8 @@ def configure_evaluate( if self.on_evaluate_config_fn is not None: # Custom evaluation config function provided config = self.on_evaluate_config_fn(server_round) + else: + config = {"current_server_round": server_round} evaluate_ins = EvaluateIns(parameters, config) # Sample clients diff --git a/fl4health/strategies/scaffold.py b/fl4health/strategies/scaffold.py index bff2accd2..f6a119a78 100644 --- a/fl4health/strategies/scaffold.py +++ b/fl4health/strategies/scaffold.py @@ -290,6 +290,8 @@ def configure_fit_all( if self.on_fit_config_fn is not None: # Custom fit config function provided config = self.on_fit_config_fn(server_round) + else: + config = {"current_server_round": server_round} fit_ins = FitIns(parameters, config) diff --git a/fl4health/utils/dataset.py b/fl4health/utils/dataset.py index 256795013..86fd42d4e 100644 --- a/fl4health/utils/dataset.py +++ b/fl4health/utils/dataset.py @@ -72,7 +72,7 @@ def __init__( transform: Callable | None = None, target_transform: Callable | None = None, ) -> None: - assert targets is not None, "SslTensorDataset targets must be None" + assert targets is None, "SslTensorDataset targets must be None" super().__init__(data, targets, transform, target_transform) diff --git a/tests/clients/test_evaluate_client.py b/tests/clients/test_evaluate_client.py index 768648758..65488703a 100644 --- a/tests/clients/test_evaluate_client.py +++ b/tests/clients/test_evaluate_client.py @@ -115,7 +115,6 @@ def test_metrics_reporter_evaluate() -> None: reporter = JsonReporter() evaluate_client = MockEvaluateClient(loss=test_loss, metrics=test_metrics, reporters=[reporter]) evaluate_client.evaluate([], {}) - print(reporter.metrics) metric_dict = { "host_type": "client", "initialized": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), diff --git a/tests/losses/test_deep_mmd_loss.py b/tests/losses/test_deep_mmd_loss.py index 61c599619..36edead13 100644 --- a/tests/losses/test_deep_mmd_loss.py +++ b/tests/losses/test_deep_mmd_loss.py @@ -51,8 +51,6 @@ def test_forward() -> None: output = deep_mmd_loss_1(X, Y) val_outputs_1.append(output) - print(train_outputs_1[0].item()) - # The output of the DeepMmdLoss in training mode should be different for each optimization step # as values are updated in each step assert pytest.approx(train_outputs_1[0].item(), abs=0.001) == 0.0573 diff --git a/tests/parameter_exchange/test_layer_exchanger.py b/tests/parameter_exchange/test_layer_exchanger.py index db0c83b05..ea240456c 100644 --- a/tests/parameter_exchange/test_layer_exchanger.py +++ b/tests/parameter_exchange/test_layer_exchanger.py @@ -99,7 +99,6 @@ def test_fedpm_exchange() -> None: # Test that selection function works when the direct child modules are masked modules. masks, score_names = select_scores_and_sample_masks(masked_model, masked_model) assert len(masks) == len(score_names) - print(score_names) assert score_names == [ "conv1d.weight_scores", "conv1d.bias_scores", diff --git a/tests/smoke_tests/feature_alignment_config.yaml b/tests/smoke_tests/feature_alignment_config.yaml deleted file mode 100644 index 346d658a7..000000000 --- a/tests/smoke_tests/feature_alignment_config.yaml +++ /dev/null @@ -1,11 +0,0 @@ -# Parameters that describe server -n_server_rounds: 3 # The number of rounds to run FL - -# Parameters that describe clients -n_clients: 2 # The number of clients in the FL experiment -local_steps: 5 # The number of local steps (one per batch) to complete for client -batch_size: 64 # The batch size for client training - -source_specified: False # Specifies whether the server knows the source -# of truth for performing feature alignment a priori. If it is False, -# then the server will randomly poll a client to obtain this information. diff --git a/tests/smoke_tests/run_smoke_test.py b/tests/smoke_tests/run_smoke_test.py index 9a2f1810f..73525cfa3 100644 --- a/tests/smoke_tests/run_smoke_test.py +++ b/tests/smoke_tests/run_smoke_test.py @@ -74,7 +74,7 @@ async def run_smoke_test( client_python_path="examples.federated_eval_example.client", config_path="tests/smoke_tests/federated_eval_config.yaml", dataset_path="examples/datasets/cifar_data/", - checkpoint_path="examples/assets/best_checkpoint_fczjmljm.pkl", + checkpoint_path="examples/assets/fed_eval_example/best_checkpoint_fczjmljm.pkl", assert_evaluation_logs=True, seed=42, server_metrics={ @@ -756,7 +756,7 @@ def load_metrics_from_file(file_path: str) -> dict[str, Any]: client_python_path="examples.federated_eval_example.client", config_path="tests/smoke_tests/federated_eval_config.yaml", dataset_path="examples/datasets/cifar_data/", - checkpoint_path="examples/assets/best_checkpoint_fczjmljm.pkl", + checkpoint_path="examples/assets/fed_eval_example/best_checkpoint_fczjmljm.pkl", assert_evaluation_logs=True, ) ) diff --git a/tests/utils/sampler_test.py b/tests/utils/sampler_test.py index eb531c2be..44ed0b3ef 100644 --- a/tests/utils/sampler_test.py +++ b/tests/utils/sampler_test.py @@ -249,16 +249,12 @@ def test_dirichlet_sampler_with_hash_key() -> None: train_probs = np.array([i / sum(samples_per_class) for i in samples_per_class]) test_probs = np.array([i / sum(test_samples_per_class) for i in test_samples_per_class]) - print(train_probs) - print(test_probs) # Assert that the original train and test distributions are same # atol is set to 1e-2 because there might be some rounding noise due to set fixed number of samples assert np.allclose(train_probs, test_probs, rtol=0.0, atol=1e-2) new_train_probs = np.array([i / sum(new_samples_per_class_1) for i in new_samples_per_class_1]) new_test_probs = np.array([i / sum(new_test_samples_per_class_1) for i in new_test_samples_per_class_1]) - print(new_train_probs) - print(new_test_probs) # Assert that the new train and test distributions with sampler_1 are same due to same hash_key # atol is set to 1e-2 because there might be some rounding noise due to set fixed number of samples assert np.allclose(new_train_probs, new_test_probs, rtol=0.0, atol=1e-2) From 20b37053630f11c6456fd1b4a993bcd787c789fd Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Fri, 29 Nov 2024 16:55:01 -0500 Subject: [PATCH 2/3] Small reversal of a change --- examples/dynamic_layer_exchange_example/server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dynamic_layer_exchange_example/server.py b/examples/dynamic_layer_exchange_example/server.py index 30eb30514..0a59bba16 100644 --- a/examples/dynamic_layer_exchange_example/server.py +++ b/examples/dynamic_layer_exchange_example/server.py @@ -74,7 +74,6 @@ def main(config: dict[str, Any]) -> None: fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, initial_parameters=get_all_model_parameters(model), - accept_failures=False, ) client_manager = SimpleClientManager() From c42882b0773e0202b7e650d97c386a12a67f9ec8 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Fri, 29 Nov 2024 17:02:03 -0500 Subject: [PATCH 3/3] Small fix --- examples/ensemble_example/server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/ensemble_example/server.py b/examples/ensemble_example/server.py index c52da9b8f..c3a8c623a 100644 --- a/examples/ensemble_example/server.py +++ b/examples/ensemble_example/server.py @@ -59,7 +59,6 @@ def main(config: dict[str, Any]) -> None: fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, initial_parameters=get_all_model_parameters(initial_model), - accept_failures=False, ) fl.server.start_server(