FedPer Implementation #72

Merged: 5 commits merged into main, Nov 29, 2023

Conversation

emersodb (Author):

PR Type

Feature

Short Description

ClickUp Ticket: https://app.clickup.com/t/8686ckn31

This PR adds the FedPer method to the repository. The addition is fairly straightforward within our infrastructure: it's essentially a globally trained feature extractor paired with a locally trained classification head on each client. I added an example using this infrastructure that also trains with MOON. So the example client inherits from the MOON client to apply the auxiliary losses to the model alongside local personalization, which I thought was nice.
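
For intuition, here is a minimal sketch of the FedPer structure described above (a hypothetical `SimpleFedPerModel`, not the repository's actual class):

```python
import torch
import torch.nn as nn


class SimpleFedPerModel(nn.Module):
    """Hypothetical minimal FedPer-style model: a globally aggregated feature
    extractor feeding a classification head that never leaves the client."""

    def __init__(self, global_feature_extractor: nn.Module, local_prediction_head: nn.Module) -> None:
        super().__init__()
        # The feature extractor's weights are exchanged with the server;
        # the prediction head is trained and kept locally on each client.
        self.global_feature_extractor = global_feature_extractor
        self.local_prediction_head = local_prediction_head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.global_feature_extractor(x)
        return self.local_prediction_head(features)
```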

In the course of adding the method, I noticed a few small bugs in MOON and in the auxiliary losses (like PerFCL) in the FENDA approaches. These have been fixed, and tests have been added to ensure they are indeed fixed.

Tests Added

Added tests covering MOON's contrastive loss calculations and the loss computations of MOON, FENDA, FedProx, and FedPer.

…calculated when using auxiliary losses in FENDA. Added tests for FedProx and FENDA to ensure the losses are being formed correctly. Also made a small change to the FENDA example to add the PerFCL loss for testing.
from fl4health.utils.sampler import MinorityLabelBasedSampler


class MnistFedPerClient(MoonClient):
emersodb (Author):

Note: We inherit from a MOON client here intentionally to be able to use auxiliary losses associated with the global module's feature space in addition to the personalized architecture.

@@ -26,7 +26,7 @@ def __init__(
         device: torch.device,
         minority_numbers: Set[int],
     ) -> None:
-        super().__init__(data_path=data_path, metrics=metrics, device=device)
+        super().__init__(data_path=data_path, metrics=metrics, device=device, perfcl_loss_weights=(1.0, 1.0))
emersodb (Author):

Adding the PerFCL loss to the FENDA example. This is just for testing when running the example, to make sure nothing is broken there.

@@ -207,7 +205,7 @@ def compute_loss(
         """

         loss = self.criterion(preds["prediction"], target)
-        total_loss = loss
+        total_loss = loss.clone()
emersodb (Author), Nov 22, 2023:

Without clone, total_loss and loss share memory, so anything added to total_loss below is also added to loss. As a result, the checkpoint and backward losses in the loss object end up identical, which we don't want. I added a unit test to make sure the clone here fixes the issue.
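
A minimal reproduction of the aliasing behavior (illustrative sketch, not the repository's code):

```python
import torch

# Without clone, both names point at the same tensor storage.
loss = torch.tensor(2.0)
total_loss = loss
total_loss += 1.0   # in-place add mutates loss too
print(loss.item())  # 3.0 -- loss was silently modified

# With clone, total_loss gets its own storage.
loss = torch.tensor(2.0)
total_loss = loss.clone()
total_loss += 1.0
print(loss.item())  # 2.0 -- loss is untouched
```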

_, old_model_features = old_model(input)
old_features[i] = old_model_features["features"]
_, global_model_features = self.global_model(input)
features.update({"global_features": global_model_features["features"], "old_features": old_features})
emersodb (Author):

With the refactor to use a features dictionary, this predict function ended up being broken. I believe these changes fix the issue, but correct me if I'm wrong.

Collaborator:

That is true; I fixed it in the contrastive losses PR, and it seems correct on the version of main that I have. However, it might change in further merges. Thanks for pointing it out.

@@ -89,8 +89,6 @@ def get_contrastive_loss(
         return self.ce_criterion(logits, labels)

     def set_parameters(self, parameters: NDArrays, config: Config) -> None:
-        assert isinstance(self.model, MoonModel)
emersodb (Author):

This is unnecessary, unless I'm mistaken. Removing it also allows FedPer models to be used with MOON clients, which is nice...

Collaborator:

Yes, it seems mypy doesn't have a problem with omitting it (that was originally why it was there). However, if we want to build FedPer on top of MOON, why don't we inherit from the MOON model? This would make it easier for users to understand that FedPer models can be used with MOON clients.

emersodb (Author):

That's a fair question. The main reason I didn't do that is that FedPer models exchange a partial subset of weights and don't, at least by default, admit projection modules for their features. They are very related, though, so it's possible that unifying them is a good idea. What do you think?

Collaborator:

projection_module is also optional in moon_base, so you can easily pass None for it in the inheritance, and everything should work well. I'd prefer fedper_base to inherit from both moon and partial_layer_exchange_model so users can see the relationship between all of them and the functionality fed_per builds on from both.

emersodb (Author):

Yeah I think that makes sense. I'll do that right now.
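
For illustration, the resulting inheritance might look roughly like the sketch below; the class and attribute names are assumptions based on this thread, not the repository's exact API:

```python
from abc import ABC, abstractmethod
from typing import List, Optional

import torch.nn as nn


class PartialLayerExchangeModel(nn.Module, ABC):
    """Models that exchange only a subset of their layers with the server."""

    @abstractmethod
    def layers_to_exchange(self) -> List[str]:
        raise NotImplementedError


class MoonModel(nn.Module):
    # Simplified stand-in: MOON pairs a feature extractor with a head and an
    # optional projection module for the contrastive feature space.
    def __init__(self, base_module: nn.Module, head_module: nn.Module,
                 projection_module: Optional[nn.Module] = None) -> None:
        super().__init__()
        self.base_module = base_module
        self.projection_module = projection_module
        self.head_module = head_module


class FedPerModel(MoonModel, PartialLayerExchangeModel):
    # FedPer reuses the MOON structure with projection_module=None and
    # exchanges only the global feature extractor's layers.
    def __init__(self, global_feature_extractor: nn.Module, local_prediction_head: nn.Module) -> None:
        super().__init__(base_module=global_feature_extractor,
                         head_module=local_prediction_head,
                         projection_module=None)

    def layers_to_exchange(self) -> List[str]:
        return [name for name in self.state_dict().keys() if name.startswith("base_module.")]
```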

emersodb (Author):

I added the assert back in now that FedPer inherits from MOON.


-class ApflModule(nn.Module):
+class ApflModule(PartialLayerExchangeModel):
emersodb (Author), Nov 22, 2023:

Models that exchange a subset of their weights now inherit from an abstract base class that forces the implementation of the layers_to_exchange function. This is just a formalism for now, but it should be useful in future iterations of the parameter exchanger mechanisms.

@@ -69,5 +71,7 @@ def update_alpha(self) -> None:
         self.alpha = alpha

     def layers_to_exchange(self) -> List[str]:
-        layers_to_exchange: List[str] = [layer for layer in self.state_dict().keys() if "global_model" in layer]
+        layers_to_exchange: List[str] = [
+            layer for layer in self.state_dict().keys() if layer.startswith("global_model.")
emersodb (Author):

Changing this function to mirror FENDA and FedPer. In particular, we only want layer names that start with global_model. rather than names that merely contain it somewhere.
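
A small illustration of why prefix matching is safer than substring matching (the layer names here are made up):

```python
# Hypothetical state_dict keys; only the second would be wrongly picked up by
# a substring check.
state_dict_keys = [
    "global_model.conv1.weight",           # belongs to the global model
    "local_model.to_global_model.weight",  # merely mentions "global_model"
]

substring_match = [k for k in state_dict_keys if "global_model" in k]
prefix_match = [k for k in state_dict_keys if k.startswith("global_model.")]

print(substring_match)  # both keys -- over-matches
print(prefix_match)     # ['global_model.conv1.weight']
```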

super().__init__()
self.global_feature_extractor = global_feature_extractor
self.local_prediction_head = local_prediction_head
self.flatten_features = flatten_features
emersodb (Author):

Flatten features is used to make the model compatible with MOON, which requires the intermediate feature representations to be flattened for similarity calculations.

Collaborator:

Why don't you flatten by default? Is there a specific use case in FedPer for the features compared to MOON?

emersodb (Author):

The reason I don't flatten by default is that it changes the shape of the features tensor. If a user were going to do something downstream with the features (other than the MOON calculations), I think they would be surprised if they weren't in the expected shape.
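
A concrete illustration of the shape change involved (assumed tensor shapes, not the example's actual model):

```python
import torch
import torch.nn.functional as F

# Convolutional features come out as (batch, channels, height, width).
features = torch.randn(8, 16, 4, 4)

# MOON's contrastive similarity compares per-sample feature vectors, so the
# feature maps must be flattened to (batch, features) first.
flat = features.reshape(len(features), -1)          # shape: (8, 256)
old_flat = torch.randn(8, 16, 4, 4).reshape(8, -1)  # e.g. features from a previous model
similarity = F.cosine_similarity(flat, old_flat, dim=1)  # shape: (8,)
```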

model: nn.Module = FedPerModel(
global_feature_extractor=FedPerGloalFeatureExtractor(),
local_prediction_head=FedPerLocalPredictionHead(),
flatten_features=True,
emersodb (Author):

Flatten features is used to make the model compatible with MOON, which requires the intermediate feature representations to be flattened for similarity calculations.

Collaborator:

Maybe you should write a comment about that, either here or in fedper_base.py?

emersodb (Author):

Yes, definitely. Thanks for pointing that out.

-        # Return preds and features as seperate dictionairy as in fenda base
-        return {"prediction": preds}, {"features": features.view(len(features), -1)}
+        # Return preds and features as separate dictionary as in fenda base
+        return {"prediction": preds}, {"features": features.reshape(len(features), -1)}
emersodb (Author):

reshape is slightly more general than view.
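
A quick illustration of the difference: view requires a compatible memory layout, while reshape falls back to a copy when a zero-copy view is impossible.

```python
import torch

# Transposing produces a non-contiguous tensor, like permuted conv features.
features = torch.randn(2, 3, 4).transpose(1, 2)  # shape (2, 4, 3)

try:
    features.view(len(features), -1)  # view cannot flatten non-contiguous memory
except RuntimeError as e:
    print(f"view failed: {e}")

flat = features.reshape(len(features), -1)  # reshape copies when needed
print(flat.shape)                           # torch.Size([2, 12])
```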

emersodb marked this pull request as ready for review on November 22, 2023 at 15:21.

@@ -0,0 +1,48 @@
# FedPer Federated Learning Example
This example demonstrates training a FedPer-type model on a non-IID subset of the MNIST data. The FL server
expects three clients to be spun up (i.e., it will wait until three clients report in before starting training). Each client
Collaborator:

The number of clients is two in the config; we probably need to update the config or the README.

emersodb (Author):

Maybe I'm looking in the wrong place, but the config at examples/fedper_example/config.yaml has n_clients: 3, so I think the README is accurate?

Collaborator:

Oh, true. I somehow got confused with the feature alignment example and thought this was the README added for it.

training and validation sets of MNIST in order to synthetically induce local variations in the statistical properties
of the clients' training/validation data. In theory, the models should be able to perform well on their local data
while learning from other clients' data with different statistical properties. The subsampling is specified by
sending a list of integers between 0 and 9 to the clients when they are run with the argument `--minority_numbers`.
Collaborator:

The client in this example does not have a --minority_numbers argument, and it does not use MinorityLabelBasedSampler. I think it assigns a random train/validation split.

emersodb (Author), Nov 28, 2023:

The client here, examples/fedper_example/client.py, does take --minority_numbers and uses MinorityLabelBasedSampler in its get_data_loaders function. Please correct me if I'm wrong, though.
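
For reference, a rough sketch of the argument wiring being described (illustrative only; the actual examples/fedper_example/client.py may differ in its details):

```python
import argparse

parser = argparse.ArgumentParser(description="FedPer example client (sketch)")
parser.add_argument("--dataset_path", type=str, required=True)
# A space-separated list of digits, e.g. --minority_numbers 1 2 3 4 5
parser.add_argument("--minority_numbers", nargs="*", type=int, default=[])
args = parser.parse_args()

# The parsed digits would be handed to MinorityLabelBasedSampler inside
# get_data_loaders to downsample those labels on this client.
minority_numbers = set(args.minority_numbers)
```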

* `local_epochs`: number of epochs each client will train for locally
* `batch_size`: size of the batches each client will train on
* `n_server_rounds`: number of rounds to run FL
* `downsampling_ratio`: amount of downsampling to perform for minority digits
Collaborator:

Again, the config does not have a downsampling_ratio variable, but it has an extra source_specified variable.

emersodb (Author):

So the config does have this argument. Based on the source_specified comment, I think you might be looking at the wrong config, as that sounds like the config for the feature_alignment_example?

Once the server has started and logged "FL starting," the next step, in separate terminals, is to start the three
clients. This is done by simply running (remembering to activate your environment)
```
python -m examples.fedper_example.client --dataset_path /path/to/data --minority_numbers <sequence of numbers>
```
Collaborator:

I believe that, because of the same issue, running this command would also produce an error.

emersodb (Author):

See the comments above; I think we're okay here. I ran the example with these args and it worked appropriately.


The argument `minority_numbers` specifies which digits (0-9) in the MNIST dataset the client will subsample to
simulate non-IID data between clients. For example, `--minority_numbers 1 2 3 4 5` will ensure that the client
downsamples these digits (using the `downsampling_ratio` specified in the config).
Collaborator:

This section should also be omitted.

emersodb (Author):

Given the above, I think we're good to leave this in?


sanaAyrml (Collaborator):

Everything seems good to me now. Thanks for the PR.

fatemetkl (Collaborator) left a comment:

Everything looks good to me as well! No other comments, thanks for adding the tests!

emersodb merged commit 7305bec into main on Nov 29, 2023 (2 checks passed).
emersodb deleted the dbe/fedper_implementation branch on November 29, 2023 at 16:08.