FedPer Implementation #72

Merged · 5 commits · Nov 29, 2023
48 changes: 48 additions & 0 deletions examples/fedper_example/README.md
@@ -0,0 +1,48 @@
# FedPer Federated Learning Example
This example demonstrates training a FedPer-type model on a non-IID subset of the MNIST data. The FL server
expects three clients to be spun up (i.e., it will wait until three clients report in before starting training). Each client
**Collaborator:**

The number of clients is two in the config; we probably need to update the config or the README.

**Collaborator (author):**

Maybe I'm looking in the wrong place, but the config here, examples/fedper_example/config.yaml, has n_clients: 3. So I think the README is accurate?

**Collaborator:**

Oh, true. I somehow got confused with the feature alignment example and thought this was the README added for it.

has a modified version of the MNIST dataset. This modification subsamples certain digits from the original
training and validation sets of MNIST in order to synthetically induce local variations in the statistical properties
of the clients' training/validation data. In theory, the models should be able to perform well on their local data
while learning from other clients' data that has different statistical properties. The subsampling is specified by
sending a list of integers between 0 and 9 to the clients when they are run with the argument `--minority_numbers`.
**Collaborator:**

The client in this example does not have a --minority_numbers argument and does not use MinorityLabelBasedSampler. I think it just assigns a random train-valid split.

**Collaborator (author) (@emersodb, Nov 28, 2023):**

The client here, examples/fedper_example/client.py, does take --minority_numbers and uses MinorityLabelBasedSampler in the get_data_loaders function. Please correct me if I'm wrong, though.


The server has some custom metrics aggregation and uses Federated Averaging as its server-side optimization. The implementation uses a special type of weight exchange based on named-layer identification.
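
As a rough, self-contained sketch of the named-layer exchange idea (not the actual `FixedLayerExchanger` implementation in fl4health, whose details may differ), the exchanger can be thought of as filtering the model's `state_dict` by the layer names the model reports as exchangeable, so that only the globally shared weights are sent to the server:

```python
from typing import Dict, List

import torch
import torch.nn as nn


def select_named_layers(model: nn.Module, layer_names: List[str]) -> Dict[str, torch.Tensor]:
    # Keep only the parameters whose names start with one of the given prefixes.
    # For a FedPer-style model, the prefixes would identify the global feature
    # extractor, leaving the local prediction head out of the exchange.
    return {
        name: tensor
        for name, tensor in model.state_dict().items()
        if any(name.startswith(prefix) for prefix in layer_names)
    }
```

The personalized prediction head therefore never leaves the client, which is the core of the FedPer approach.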

## Running the Example
In order to run the example, first ensure you have the virtual env of your choice activated and run
```
pip install --upgrade pip
pip install -r requirements.txt
```
to install all of the dependencies for this project.

## Starting Server

The next step is to start the server by running
```
python -m examples.fedper_example.server --config_path /path/to/config.yaml
```
from the FL4Health directory. The following arguments must be present in the specified config file:
* `n_clients`: number of clients the server waits for in order to run the FL training
* `local_epochs`: number of epochs each client will train for locally
* `batch_size`: size of the batches each client will train on
* `n_server_rounds`: The number of rounds to run FL
* `downsampling_ratio`: The amount of downsampling to perform for minority digits
**Collaborator:**

Again, the config does not have a downsampling_ratio variable, but it has an extra source_specified variable.

**Collaborator (author):**

So the config does have this argument, but based on the source_specified comment, I think you might be looking at the wrong config, as that sounds like the config for the feature_alignment_example?


## Starting Clients

Once the server has started and logged "FL starting," the next step, in separate terminals, is to start the three
clients. This is done by simply running (remembering to activate your environment)
```
python -m examples.fedper_example.client --dataset_path /path/to/data --minority_numbers <sequence of numbers>
```

**Collaborator:**

I believe that, because of the same issue, running this command would also produce an error.

**Collaborator (author):**

See the comments above. I think we're okay here. I did run the example with these args and it worked appropriately.

**NOTE**: The argument `dataset_path` has two functions, depending on whether the dataset exists locally or not. If
the dataset already exists at the path specified, it will be loaded from there. Otherwise, the dataset will be
automatically downloaded to the path specified and used in the run.

The argument `minority_numbers` specifies which digits (0-9) in the MNIST dataset the client will subsample to
simulate non-IID data between clients. For example, `--minority_numbers 1 2 3 4 5` will ensure that the client
downsamples these digits (using the `downsampling_ratio` specified in the config).
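
The sampler itself lives in `fl4health.utils.sampler.MinorityLabelBasedSampler`; the snippet below is only a minimal sketch of the idea, under the assumption that the sampler keeps every example of the majority digits and a random `downsampling_ratio` fraction of the minority digits (the function name and signature here are illustrative, not the library's API):

```python
from typing import List, Set

import torch


def minority_downsample_indices(
    labels: torch.Tensor, minority_numbers: Set[int], downsampling_ratio: float
) -> List[int]:
    # Keep all indices for majority digits and only a random fraction of the
    # indices for digits listed in minority_numbers.
    kept: List[int] = []
    for digit in range(10):
        digit_indices = torch.where(labels == digit)[0]
        if digit in minority_numbers:
            n_keep = int(len(digit_indices) * downsampling_ratio)
            digit_indices = digit_indices[torch.randperm(len(digit_indices))[:n_keep]]
        kept.extend(digit_indices.tolist())
    return kept
```

A client launched with, say, `--minority_numbers 1 2 3` would then see comparatively few examples of those digits, which is what induces the non-IID label distributions across clients.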
**Collaborator:**

This section should also be omitted.

**Collaborator (author):**

Given the above, I think we're good to leave this in?


After all three clients have been started, federated learning should commence.
77 changes: 77 additions & 0 deletions examples/fedper_example/client.py
@@ -0,0 +1,77 @@
import argparse
from pathlib import Path
from typing import Sequence, Set, Tuple

import flwr as fl
import torch
import torch.nn as nn
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from examples.models.fedper_cnn import FedPerGloalFeatureExtractor, FedPerLocalPredictionHead
from fl4health.clients.moon_client import MoonClient
from fl4health.model_bases.fedper_base import FedPerModel
from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger
from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy, Metric
from fl4health.utils.sampler import MinorityLabelBasedSampler


class MnistFedPerClient(MoonClient):
**Collaborator (author):**

Note: We inherit from a MOON client here intentionally to be able to use auxiliary losses associated with the global module's feature space in addition to the personalized architecture.

def __init__(
self,
data_path: Path,
metrics: Sequence[Metric],
device: torch.device,
minority_numbers: Set[int],
) -> None:
# We inherit from a MOON client here intentionally to be able to use auxiliary losses associated with the
# global module's feature space in addition to the personalized architecture of FedPer.
super().__init__(data_path=data_path, metrics=metrics, device=device)
self.minority_numbers = minority_numbers

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = self.narrow_config_type(config, "batch_size", int)
downsample_percentage = self.narrow_config_type(config, "downsampling_ratio", float)
sampler = MinorityLabelBasedSampler(list(range(10)), downsample_percentage, self.minority_numbers)
train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler)
return train_loader, val_loader

def get_model(self, config: Config) -> nn.Module:
# NOTE: Flatten features is set to true to make the model compatible with the MOON contrastive loss function,
# which requires the intermediate feature representations to be flattened for similarity calculations.
model: nn.Module = FedPerModel(
global_feature_extractor=FedPerGloalFeatureExtractor(),
local_prediction_head=FedPerLocalPredictionHead(),
flatten_features=True,
**Collaborator (author):**

Flatten features is used to make the model compatible with MOON, which requires the intermediate feature representations to be flattened for similarity calculations.

**Collaborator:**

Maybe you should write a comment regarding that? Here or in fedper_base.py.

**Collaborator (author):**

Yes, definitely. Thanks for pointing that out.

).to(self.device)
return model

def get_optimizer(self, config: Config) -> Optimizer:
return torch.optim.AdamW(self.model.parameters(), lr=0.001)

def get_criterion(self, config: Config) -> _Loss:
return torch.nn.CrossEntropyLoss()

def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
assert isinstance(self.model, FedPerModel)
return FixedLayerExchanger(self.model.layers_to_exchange())


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="FL Client Main")
parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset")
parser.add_argument(
"--minority_numbers", default=[], nargs="*", help="MNIST numbers to be in the minority for the current client"
)
args = parser.parse_args()

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
minority_numbers = {int(number) for number in args.minority_numbers}
client = MnistFedPerClient(data_path, [Accuracy("accuracy")], DEVICE, minority_numbers)
fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=client)
client.shutdown()
10 changes: 10 additions & 0 deletions examples/fedper_example/config.yaml
@@ -0,0 +1,10 @@
# Parameters that describe server
n_server_rounds: 3 # The number of rounds to run FL

# Parameters that describe clients
n_clients: 3 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client
batch_size: 32 # The batch size for client training

# Downsampling settings per client
downsampling_ratio: 0.1 # percentage of original mnist data to keep for minority numbers
87 changes: 87 additions & 0 deletions examples/fedper_example/server.py
@@ -0,0 +1,87 @@
import argparse
from functools import partial
from typing import Any, Dict

import flwr as fl
from flwr.common.parameter import ndarrays_to_parameters
from flwr.common.typing import Config, Parameters
from flwr.server.strategy import FedAvg

from examples.models.fedper_cnn import FedPerGloalFeatureExtractor, FedPerLocalPredictionHead
from examples.simple_metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.model_bases.fedper_base import FedPerModel
from fl4health.utils.config import load_config


def get_initial_model_parameters() -> Parameters:
# Initializing the model parameters on the server side.
# Currently uses the Pytorch default initialization for the model parameters.
initial_model = FedPerModel(
global_feature_extractor=FedPerGloalFeatureExtractor(),
local_prediction_head=FedPerLocalPredictionHead(),
flatten_features=True,
)
return ndarrays_to_parameters([val.cpu().numpy() for _, val in initial_model.state_dict().items()])


def fit_config(
local_epochs: int,
batch_size: int,
n_server_rounds: int,
downsampling_ratio: float,
current_round: int,
) -> Config:
return {
"local_epochs": local_epochs,
"batch_size": batch_size,
"n_server_rounds": n_server_rounds,
"downsampling_ratio": downsampling_ratio,
"current_server_round": current_round,
}


def main(config: Dict[str, Any]) -> None:
# This function will be used to produce a config that is sent to each client to initialize their own environment
fit_config_fn = partial(
fit_config,
config["local_epochs"],
config["batch_size"],
config["n_server_rounds"],
config["downsampling_ratio"],
)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
min_fit_clients=config["n_clients"],
min_evaluate_clients=config["n_clients"],
# Server waits for min_available_clients before starting FL rounds
min_available_clients=config["n_clients"],
on_fit_config_fn=fit_config_fn,
# We use the same fit config function, as nothing changes for eval
on_evaluate_config_fn=fit_config_fn,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
initial_parameters=get_initial_model_parameters(),
)

fl.server.start_server(
server_address="0.0.0.0:8080",
config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
strategy=strategy,
)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="FL Server Main")
parser.add_argument(
"--config_path",
action="store",
type=str,
help="Path to configuration file.",
default="examples/fedper_example/config.yaml",
)
args = parser.parse_args()

config = load_config(args.config_path)

main(config)
2 changes: 1 addition & 1 deletion examples/fenda_example/client.py
@@ -26,7 +26,7 @@ def __init__(
device: torch.device,
minority_numbers: Set[int],
) -> None:
super().__init__(data_path=data_path, metrics=metrics, device=device)
super().__init__(data_path=data_path, metrics=metrics, device=device, perfcl_loss_weights=(1.0, 1.0))
**Collaborator (author):**

Adding in the perfcl loss to the FENDA example. This is just for testing when running the example to make sure nothing is broken there.

self.minority_numbers = minority_numbers

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
31 changes: 31 additions & 0 deletions examples/models/fedper_cnn.py
@@ -0,0 +1,31 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class FedPerLocalPredictionHead(nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc1 = nn.Linear(120, 84)
self.fc2 = nn.Linear(84, 10)

def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
x = F.relu(self.fc1(input_tensor))
x = self.fc2(x)
return x


class FedPerGloalFeatureExtractor(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 4 * 4, 120)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 4 * 4)
x = F.relu(self.fc1(x))
return x
4 changes: 2 additions & 2 deletions fl4health/clients/apfl_client.py
@@ -68,7 +68,7 @@ def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[Losses,
losses.backward.backward()
self.local_optimizer.step()

# Return dictionairy of predictions where key is used to name respective MetricMeters
# Return dictionary of predictions where key is used to name respective MetricMeters
return losses, preds

def get_parameter_exchanger(self, config: Config) -> FixedLayerExchanger:
@@ -108,6 +108,6 @@ def set_optimizer(self, config: Config) -> None:

def get_optimizer(self, config: Config) -> Dict[str, Optimizer]:
"""
Returns a dictionairy with global and local optimizers with string keys 'global' and 'local' respectively.
Returns a dictionary with global and local optimizers with string keys 'global' and 'local' respectively.
"""
raise NotImplementedError
4 changes: 2 additions & 2 deletions fl4health/clients/basic_client.py
@@ -366,7 +366,7 @@ def compute_loss(
def set_optimizer(self, config: Config) -> None:
"""
Method called in the the setup_client method to set optimizer attribute returned by used-defined get_optimizer.
In the simplest case, get_optimizer returns an optimizer. For more advanced use cases where a dictionairy of
In the simplest case, get_optimizer returns an optimizer. For more advanced use cases where a dictionary of
string and optimizer are returned (ie APFL), the use must override this method.
"""
optimizer = self.get_optimizer(config)
@@ -419,7 +419,7 @@ def get_model(self, config: Config) -> nn.Module:
def update_after_train(self, local_steps: int, loss_dict: Dict[str, float]) -> None:
"""
Called after training with the number of local_steps performed over the FL round and
the corresponding loss dictionairy.
the corresponding loss dictionary.
"""
pass

2 changes: 1 addition & 1 deletion fl4health/clients/evaluate_client.py
@@ -198,7 +198,7 @@ def predict(self, input: torch.Tensor) -> torch.Tensor:

def compute_loss(self, preds: torch.Tensor, target: torch.Tensor) -> Losses:
"""
Computes loss given preds and torch and the user defined criterion. Optionally includes dictionairy of
Computes loss given preds and torch and the user defined criterion. Optionally includes dictionary of
loss components if you wish to train the total loss as well as sub losses if they exist.
"""
loss = self.criterion(preds, target)
2 changes: 1 addition & 1 deletion fl4health/clients/fed_prox_client.py
@@ -123,7 +123,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
def update_after_train(self, local_steps: int, loss_dict: Dict[str, float]) -> None:
"""
Called after training with the number of local_steps performed over the FL round and
the corresponding loss dictionairy.
the corresponding loss dictionary.
"""
# Store current loss which is the vanilla loss without the proximal term added in
self.current_loss = loss_dict["checkpoint"]
4 changes: 1 addition & 3 deletions fl4health/clients/fenda_client.py
@@ -95,7 +95,6 @@ def predict(self, input: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], Dict[st
return preds, features

def get_parameters(self, config: Config) -> NDArrays:

# Save the parameters of the old model
assert isinstance(self.model, FendaModel)
if self.contrastive_loss_weight or self.perfcl_loss_weights:
@@ -105,7 +104,6 @@ def get_parameters(self, config: Config) -> NDArrays:
return super().get_parameters(config)

def set_parameters(self, parameters: NDArrays, config: Config) -> None:

# Set the parameters of the model
super().set_parameters(parameters, config)

@@ -207,7 +205,7 @@ def compute_loss(
"""

loss = self.criterion(preds["prediction"], target)
total_loss = loss
total_loss = loss.clone()
**Collaborator (author) (@emersodb, Nov 22, 2023):**

Without clone, total_loss and loss share memory, so anything added to total_loss below is also added to loss. As a result, checkpoint and backward in the loss object end up being identical, which we don't want. I added a unit test to make sure the clone here fixes the issue.
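
As a small, standalone illustration of the aliasing behavior described above (independent of the FENDA client itself), assigning a tensor without `clone()` keeps both names pointing at the same storage, so in-place additions affect both:

```python
import torch

# Without clone: total_loss is the same tensor object as loss, so an
# in-place addition to total_loss silently modifies loss as well.
loss = torch.tensor(1.0)
total_loss = loss
total_loss += 2.0
assert loss.item() == 3.0

# With clone: total_loss gets its own storage and loss is left untouched.
loss = torch.tensor(1.0)
total_loss = loss.clone()
total_loss += 2.0
assert loss.item() == 1.0
```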

additional_losses = {}

# Optimal cos_sim_loss_weight for FedIsic dataset is 100.0
10 changes: 5 additions & 5 deletions fl4health/clients/moon_client.py
@@ -60,11 +60,12 @@ def predict(self, input: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], Dict[st
the old model are returned. All predictions included in dictionary will be used to compute metrics.
"""
preds, features = self.model(input)
old_features = torch.zeros(self.len_old_models_buffer, *features.size()).to(self.device)
old_features = torch.zeros(self.len_old_models_buffer, *features["features"].size()).to(self.device)
for i, old_model in enumerate(self.old_models_list):
old_features[i] = old_model(input)[1]
global_features = self.global_model(input)[1]
features.update({"global_features": global_features, "old_features": old_features})
_, old_model_features = old_model(input)
old_features[i] = old_model_features["features"]
_, global_model_features = self.global_model(input)
features.update({"global_features": global_model_features["features"], "old_features": old_features})
**Collaborator (author):**

With the refactor to use a features dictionary, this predict function ended up broken. I believe these changes fix the issue, but correct me if I'm wrong.

**Collaborator:**

That is true; I fixed it in the contrastive losses PR, and it seems correct on the version of main that I have. However, it might change in further merges. Thanks for pointing it out.

return preds, features

def get_contrastive_loss(
@@ -90,7 +91,6 @@ def get_contrastive_loss(

def set_parameters(self, parameters: NDArrays, config: Config) -> None:
assert isinstance(self.model, MoonModel)
**Collaborator (author):**

This is unnecessary, unless I'm mistaken. Removing it also allows for FedPer models to be used with MOON Clients, which is nice...

**Collaborator:**

Yes, it seems mypy doesn't have a problem with omitting it (it was originally there for that). However, if we want to build FedPer on top of MOON, why don't we inherit it from the MOON model? That would make it easier for users to understand that FedPer models can be used with MOON clients.

**Collaborator (author):**

That's a fair question. The main reason I didn't do that is that FedPer models exchange a partial subset of weights and don't, at least by default, admit projection modules for their features. They are very related, though, so it's possible that unifying them is a good idea. What do you think?

**Collaborator:**

projection_module is also optional in moon_base, so you can easily pass None for it in the inheritance, and everything should work well. I'd prefer fedper_base to inherit from both moon and partial_layer_exchange_model so users can see the relationship between all of them and the added functionality of fed_per built on both.

**Collaborator (author):**

Yeah I think that makes sense. I'll do that right now.

**Collaborator (author):**

I added the assert back in now that FedPer inherits from MOON.
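
To make the relationship discussed in this thread concrete, here is a purely illustrative, self-contained sketch of a MOON-style base and a FedPer-style specialization. These toy classes are not the fl4health `MoonModel`/`FedPerModel`, and the constructor arguments and names are assumptions:

```python
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn


class ToyMoonStyleModel(nn.Module):
    # Feature extractor, optional projection module, and prediction head,
    # returning a predictions dictionary and a (flattened) features dictionary.
    def __init__(
        self,
        feature_extractor: nn.Module,
        prediction_head: nn.Module,
        projection_module: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        self.feature_extractor = feature_extractor
        self.projection_module = projection_module
        self.prediction_head = prediction_head

    def forward(self, x: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        features = self.feature_extractor(x)
        if self.projection_module is not None:
            features = self.projection_module(features)
        preds = {"prediction": self.prediction_head(features)}
        # Flattened features are what a MOON-style contrastive loss would compare.
        return preds, {"features": features.flatten(start_dim=1)}


class ToyFedPerStyleModel(ToyMoonStyleModel):
    # FedPer-style specialization: no projection module, and only the shared
    # feature extractor's layers are exchanged with the server.
    def __init__(self, feature_extractor: nn.Module, prediction_head: nn.Module) -> None:
        super().__init__(feature_extractor, prediction_head, projection_module=None)

    def layers_to_exchange(self) -> List[str]:
        return [name for name in self.state_dict() if name.startswith("feature_extractor.")]
```

With this arrangement, a MOON-style client can consume the feature dictionary for its auxiliary losses, while the parameter exchanger only ships the `feature_extractor.*` weights.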


# Save the parameters of the old local model
old_model = self.clone_and_freeze_model(self.model)
self.old_models_list.append(old_model)
2 changes: 1 addition & 1 deletion fl4health/clients/scaffold_client.py
@@ -201,7 +201,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
def update_after_train(self, local_steps: int, loss_dict: Dict[str, float]) -> None:
"""
Called after training with the number of local_steps performed over the FL round and
the corresponding loss dictionairy.
the corresponding loss dictionary.
"""
self.update_control_variates(local_steps)
