FedPer Implementation #72
Changes from all commits
698f9d6
685dce5
0471048
186b788
fc07bbf
@@ -0,0 +1,48 @@
# FedPer Federated Learning Example
This example demonstrates training a FedPer-style model on a non-IID subset of the MNIST data. The FL server expects three clients to be spun up (i.e. it will wait until three clients report in before starting training). Each client has a modified version of the MNIST dataset. This modification subsamples a certain number of examples from the original training and validation sets of MNIST in order to synthetically induce local variations in the statistical properties of the clients' training/validation data. In theory, the models should be able to perform well on their local data while learning from other clients' data that has different statistical properties. The subsampling is specified by sending a list of integers between 0 and 9 to the clients when they are run with the argument `--minority_numbers`.
> **Review comment:** Client in this example does not have …
> **Reply:** The client here: …
The server performs some custom metrics aggregation and uses Federated Averaging as its server-side optimization. The implementation uses a special type of weight exchange based on named-layer identification (a sketch of this idea is shown below).
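To give a sense of what "named-layer identification" means here, the following is a minimal, illustrative sketch of exchanging only a fixed set of named submodules between client and server. The helper names are hypothetical; the actual mechanism is fl4health's `FixedLayerExchanger`, which the client in this PR constructs from `FedPerModel.layers_to_exchange()`.

```
from typing import Dict, List

import torch
import torch.nn as nn


def extract_exchanged_layers(model: nn.Module, layers_to_exchange: List[str]) -> Dict[str, torch.Tensor]:
    # Keep only weights whose state_dict key starts with one of the exchanged module names
    # (e.g. the global feature extractor), leaving the personalized layers out of federation.
    return {
        name: tensor.cpu().clone()
        for name, tensor in model.state_dict().items()
        if any(name.startswith(prefix) for prefix in layers_to_exchange)
    }


def load_exchanged_layers(model: nn.Module, aggregated: Dict[str, torch.Tensor]) -> None:
    # Load the aggregated global layers back in; strict=False leaves the local layers untouched.
    model.load_state_dict(aggregated, strict=False)
```

In the FedPer setup below, only the global feature extractor would be exchanged and federally averaged, while each client's local prediction head stays personal.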
## Running the Example
In order to run the example, first ensure you have the virtual env of your choice activated and run
```
pip install --upgrade pip
pip install -r requirements.txt
```
to install all of the dependencies for this project.
## Starting Server
The next step is to start the server by running
```
python -m examples.fedper_example.server --config_path /path/to/config.yaml
```
from the FL4Health directory. The following arguments must be present in the specified config file (see the example config below):
* `n_clients`: number of clients the server waits for in order to run the FL training
* `local_epochs`: number of epochs each client will train for locally
* `batch_size`: size of the batches each client will train on
* `n_server_rounds`: the number of rounds to run FL
* `downsampling_ratio`: the amount of downsampling to perform for minority digits
> **Review comment:** Again, Config does not have a …
> **Reply:** So the config does have this argument, but based on the …
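For reference, a config with these keys looks like the following (this mirrors the `examples/fedper_example/config.yaml` added in this PR):

```
n_clients: 3
local_epochs: 1
batch_size: 32
n_server_rounds: 3
downsampling_ratio: 0.1
```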
## Starting Clients
Once the server has started and logged "FL starting," the next step, in separate terminals, is to start the three clients. This is done by simply running (remembering to activate your environment)
```
python -m examples.fedper_example.client --dataset_path /path/to/data --minority_numbers <sequence of numbers>
```

> **Review comment:** I believe, because of the same issue, running this command would also get an error.
> **Reply:** See the comments above. I think we're okay here. I did run the example with these args and it worked appropriately.
**NOTE**: The argument `dataset_path` has two functions, depending on whether the dataset exists locally or not. If the dataset already exists at the path specified, it will be loaded from there. Otherwise, the dataset will be automatically downloaded to the path specified and used in the run.
The argument `minority_numbers` specifies which digits (0-9) in the MNIST dataset the client will subsample to simulate non-IID data between clients. For example, `--minority_numbers 1 2 3 4 5` will ensure that the client downsamples these digits (using the `downsampling_ratio` specified in the config). A sketch of this kind of label-based downsampling appears below.

> **Review comment:** This section also should be omitted.
> **Reply:** Given the above, I think we're good to leave this in?
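The sketch below illustrates the general idea of label-based downsampling. It is only an illustration; the actual logic lives in fl4health's `MinorityLabelBasedSampler` and may differ in detail.

```
import torch


def downsample_minority_labels(labels: torch.Tensor, minority_numbers: set, downsampling_ratio: float) -> torch.Tensor:
    # Keep every sample of the majority digits, but only `downsampling_ratio` of each minority digit.
    keep_mask = torch.ones(len(labels), dtype=torch.bool)
    for digit in minority_numbers:
        digit_indices = torch.where(labels == digit)[0]
        num_to_keep = int(len(digit_indices) * downsampling_ratio)
        shuffled = digit_indices[torch.randperm(len(digit_indices))]
        keep_mask[shuffled[num_to_keep:]] = False
    return torch.where(keep_mask)[0]


# For example, keep only 10% of digits 1-5 for this client.
labels = torch.randint(0, 10, (60000,))
kept_indices = downsample_minority_labels(labels, {1, 2, 3, 4, 5}, 0.1)
```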
After all three clients have been started, federated learning should commence.
@@ -0,0 +1,77 @@
import argparse
from pathlib import Path
from typing import Sequence, Set, Tuple

import flwr as fl
import torch
import torch.nn as nn
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from examples.models.fedper_cnn import FedPerGloalFeatureExtractor, FedPerLocalPredictionHead
from fl4health.clients.moon_client import MoonClient
from fl4health.model_bases.fedper_base import FedPerModel
from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger
from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy, Metric
from fl4health.utils.sampler import MinorityLabelBasedSampler

class MnistFedPerClient(MoonClient):

> **Review comment:** Note: We inherit from a MOON client here intentionally to be able to use auxiliary losses associated with the global module's feature space in addition to the personalized architecture.
    def __init__(
        self,
        data_path: Path,
        metrics: Sequence[Metric],
        device: torch.device,
        minority_numbers: Set[int],
    ) -> None:
        # We inherit from a MOON client here intentionally to be able to use auxiliary losses associated with the
        # global module's feature space in addition to the personalized architecture of FedPer.
        super().__init__(data_path=data_path, metrics=metrics, device=device)
        self.minority_numbers = minority_numbers

    def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
        batch_size = self.narrow_config_type(config, "batch_size", int)
        downsample_percentage = self.narrow_config_type(config, "downsampling_ratio", float)
        sampler = MinorityLabelBasedSampler(list(range(10)), downsample_percentage, self.minority_numbers)
        train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler)
        return train_loader, val_loader
    def get_model(self, config: Config) -> nn.Module:
        # NOTE: Flatten features is set to true to make the model compatible with the MOON contrastive loss function,
        # which requires the intermediate feature representations to be flattened for similarity calculations.
        model: nn.Module = FedPerModel(
            global_feature_extractor=FedPerGloalFeatureExtractor(),
            local_prediction_head=FedPerLocalPredictionHead(),
            flatten_features=True,
        ).to(self.device)
        return model

> **Review comment:** Flatten features is used to make the model compatible with MOON, which requires the intermediate feature representations to be flattened for similarity calculations.
> **Reply:** Maybe you should write a comment regarding that? Here or in …
> **Reply:** Yes, definitely. Thanks for pointing that out.
    def get_optimizer(self, config: Config) -> Optimizer:
        return torch.optim.AdamW(self.model.parameters(), lr=0.001)

    def get_criterion(self, config: Config) -> _Loss:
        return torch.nn.CrossEntropyLoss()

    def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
        assert isinstance(self.model, FedPerModel)
        return FixedLayerExchanger(self.model.layers_to_exchange())

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="FL Client Main")
    parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset")
    parser.add_argument(
        "--minority_numbers", default=[], nargs="*", help="MNIST numbers to be in the minority for the current client"
    )
    args = parser.parse_args()

    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_path = Path(args.dataset_path)
    minority_numbers = {int(number) for number in args.minority_numbers}
    client = MnistFedPerClient(data_path, [Accuracy("accuracy")], DEVICE, minority_numbers)
    fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=client)
    client.shutdown()
@@ -0,0 +1,10 @@
# Parameters that describe server
n_server_rounds: 3 # The number of rounds to run FL

# Parameters that describe clients
n_clients: 3 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client
batch_size: 32 # The batch size for client training

# Downsampling settings per client
downsampling_ratio: 0.1 # percentage of original mnist data to keep for minority numbers
@@ -0,0 +1,87 @@
import argparse
from functools import partial
from typing import Any, Dict

import flwr as fl
from flwr.common.parameter import ndarrays_to_parameters
from flwr.common.typing import Config, Parameters
from flwr.server.strategy import FedAvg

from examples.models.fedper_cnn import FedPerGloalFeatureExtractor, FedPerLocalPredictionHead
from examples.simple_metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.model_bases.fedper_base import FedPerModel
from fl4health.utils.config import load_config


def get_initial_model_parameters() -> Parameters:
    # Initializing the model parameters on the server side.
    # Currently uses the Pytorch default initialization for the model parameters.
    initial_model = FedPerModel(
        global_feature_extractor=FedPerGloalFeatureExtractor(),
        local_prediction_head=FedPerLocalPredictionHead(),
        flatten_features=True,
    )
    return ndarrays_to_parameters([val.cpu().numpy() for _, val in initial_model.state_dict().items()])

def fit_config(
    local_epochs: int,
    batch_size: int,
    n_server_rounds: int,
    downsampling_ratio: float,
    current_round: int,
) -> Config:
    return {
        "local_epochs": local_epochs,
        "batch_size": batch_size,
        "n_server_rounds": n_server_rounds,
        "downsampling_ratio": downsampling_ratio,
        "current_server_round": current_round,
    }

def main(config: Dict[str, Any]) -> None:
    # This function will be used to produce a config that is sent to each client to initialize their own environment
    fit_config_fn = partial(
        fit_config,
        config["local_epochs"],
        config["batch_size"],
        config["n_server_rounds"],
        config["downsampling_ratio"],
    )

    # Server performs simple FedAveraging as its server-side optimization strategy
    strategy = FedAvg(
        min_fit_clients=config["n_clients"],
        min_evaluate_clients=config["n_clients"],
        # Server waits for min_available_clients before starting FL rounds
        min_available_clients=config["n_clients"],
        on_fit_config_fn=fit_config_fn,
        # We use the same fit config function, as nothing changes for eval
        on_evaluate_config_fn=fit_config_fn,
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        initial_parameters=get_initial_model_parameters(),
    )

    fl.server.start_server(
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
        strategy=strategy,
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="FL Server Main")
    parser.add_argument(
        "--config_path",
        action="store",
        type=str,
        help="Path to configuration file.",
        default="examples/fedper_example/config.yaml",
    )
    args = parser.parse_args()

    config = load_config(args.config_path)

    main(config)
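As a small clarification of the `partial` used in `main` above: Flower calls `on_fit_config_fn` with only the current server round, while the remaining arguments have already been bound from the loaded config. An illustration, assuming `fit_config` from this file is in scope and using the values from the example config:

```
from functools import partial

# local_epochs, batch_size, n_server_rounds, downsampling_ratio are bound up front.
fit_config_fn = partial(fit_config, 1, 32, 3, 0.1)

# Flower invokes the function once per round with the round number, e.g. round 2:
print(fit_config_fn(2))
# {'local_epochs': 1, 'batch_size': 32, 'n_server_rounds': 3,
#  'downsampling_ratio': 0.1, 'current_server_round': 2}
```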
@@ -26,7 +26,7 @@ def __init__(
        device: torch.device,
        minority_numbers: Set[int],
    ) -> None:
-       super().__init__(data_path=data_path, metrics=metrics, device=device)
+       super().__init__(data_path=data_path, metrics=metrics, device=device, perfcl_loss_weights=(1.0, 1.0))
        self.minority_numbers = minority_numbers

    def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:

> **Review comment:** Adding in the perfcl loss to the FENDA example. This is just for testing when running the example to make sure nothing is broken there.
@@ -0,0 +1,31 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class FedPerLocalPredictionHead(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 10)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.fc1(input_tensor))
        x = self.fc2(x)
        return x


class FedPerGloalFeatureExtractor(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        return x
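To see how the two modules fit together, here is a quick shape check with a dummy MNIST-sized batch: 28x28 inputs become 24x24 after the first 5x5 convolution, 12x12 after pooling, 8x8 after the second convolution, and 4x4 after the final pool, which is where the `16 * 4 * 4 = 256` flattened features entering `fc1` come from.

```
import torch

# Assumes the two modules defined above are in scope.
feature_extractor = FedPerGloalFeatureExtractor()
prediction_head = FedPerLocalPredictionHead()

dummy_batch = torch.randn(8, 1, 28, 28)  # batch of 8 single-channel MNIST images
features = feature_extractor(dummy_batch)
logits = prediction_head(features)

print(features.shape)  # torch.Size([8, 120])
print(logits.shape)    # torch.Size([8, 10])
```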
@@ -95,7 +95,6 @@ def predict(self, input: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], Dict[st
        return preds, features

    def get_parameters(self, config: Config) -> NDArrays:
        # Save the parameters of the old model
        assert isinstance(self.model, FendaModel)
        if self.contrastive_loss_weight or self.perfcl_loss_weights:

@@ -105,7 +104,6 @@ def get_parameters(self, config: Config) -> NDArrays:
        return super().get_parameters(config)

    def set_parameters(self, parameters: NDArrays, config: Config) -> None:
        # Set the parameters of the model
        super().set_parameters(parameters, config)

@@ -207,7 +205,7 @@ def compute_loss(
        """
        loss = self.criterion(preds["prediction"], target)
-       total_loss = loss
+       total_loss = loss.clone()

> **Review comment:** Without clone, …

        additional_losses = {}

        # Optimal cos_sim_loss_weight for FedIsic dataset is 100.0
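For context on the `clone()` change above: if `total_loss` merely aliased `loss`, any in-place accumulation of auxiliary losses would silently mutate the reported task loss as well. A minimal, standalone illustration of the aliasing behaviour (not the fl4health code itself):

```
import torch

loss = torch.tensor(2.0)
total_loss = loss            # alias: both names refer to the same tensor
total_loss += 3.0            # in-place add also mutates `loss`
print(loss.item())           # 5.0 - the reported task loss has drifted

loss = torch.tensor(2.0)
total_loss = loss.clone()    # independent copy
total_loss += 3.0            # only the copy changes
print(loss.item())           # 2.0 - task loss preserved
```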
@@ -60,11 +60,12 @@ def predict(self, input: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], Dict[st
        the old model are returned. All predictions included in dictionary will be used to compute metrics.
        """
        preds, features = self.model(input)
-       old_features = torch.zeros(self.len_old_models_buffer, *features.size()).to(self.device)
+       old_features = torch.zeros(self.len_old_models_buffer, *features["features"].size()).to(self.device)
        for i, old_model in enumerate(self.old_models_list):
-           old_features[i] = old_model(input)[1]
-       global_features = self.global_model(input)[1]
-       features.update({"global_features": global_features, "old_features": old_features})
+           _, old_model_features = old_model(input)
+           old_features[i] = old_model_features["features"]
+       _, global_model_features = self.global_model(input)
+       features.update({"global_features": global_model_features["features"], "old_features": old_features})
        return preds, features

> **Review comment:** With the refactor of the use of a features dictionary, this predict function ended up being broken. I believe these changes fix the issue. However, correct me if I'm wrong.
> **Reply:** That is true; I had fixed it in the contrastive losses PR and it seems correct on the version of main that I have. However, it might be changed in further merges. Thanks for pointing it out.

    def get_contrastive_loss(
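For readers following the features-dictionary refactor: after this change, `predict` returns a features dictionary with three entries. The snippet below only illustrates the assumed structure, using the FedPer example's 120-dimensional flattened features and a buffer of two old models; it is not the MoonClient implementation.

```
import torch

batch_size, feature_dim, num_old_models = 8, 120, 2

features = {
    "features": torch.randn(batch_size, feature_dim),                      # current local model's flattened features
    "global_features": torch.randn(batch_size, feature_dim),               # features from the frozen global model
    "old_features": torch.randn(num_old_models, batch_size, feature_dim),  # stacked features from the old-model buffer
}
print({key: tuple(value.shape) for key, value in features.items()})
```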
@@ -90,7 +91,6 @@ def get_contrastive_loss(

    def set_parameters(self, parameters: NDArrays, config: Config) -> None:
-       assert isinstance(self.model, MoonModel)

> **Review comment:** This is unnecessary, unless I'm mistaken. Removing it also allows FedPer models to be used with MOON clients, which is nice...
> **Reply:** Yes, it seems like mypy doesn't have a problem with omitting it (it was originally there for that). However, if we want to build FedPer on top of MOON, why don't we inherit it from the MOON model? This would make it easier for users to understand that FedPer models can be used with MOON clients.
> **Reply:** That's a fair question. The main reason I didn't do that was because FedPer models exchange a partial subset of weights and don't, at least by default, admit projection modules for their features. They are very related, so it's possible that unifying them is a good idea. What do you think?
> **Reply:** projection_module is also optional in moon_base, so you can easily pass None for it in inheritance, and everything should work well. I'd kinda prefer fedper_base to inherit from both moon and partial_layer_exchange_model so users can see the relationship between all of them and the added functionality of fedper_base built on both.
> **Reply:** Yeah, I think that makes sense. I'll do that right now.
> **Reply:** I added the assert back in now that FedPer inherits from MOON.

        # Save the parameters of the old local model
        old_model = self.clone_and_freeze_model(self.model)
        self.old_models_list.append(old_model)
> **Review comment:** The number of clients is two in the config; probably we need to update the config or the readme.
> **Reply:** Maybe I'm looking in the wrong place, but the config here, `examples/fedper_example/config.yaml`, has `n_clients: 3`. So I think the readme is accurate?
> **Reply:** Oh true, I somehow got confused with the feature alignment example and thought this was the readme added for it.