FedPer Implementation #72
Conversation
…o modify the readme and test the run.
…calculated when using auxiliary losses in FENDA. Added tests for FedProx and FENDA to ensure the losses are being formed correctly. Also making a small change to the FENDA example to add the PerFCL loss for testing.
from fl4health.utils.sampler import MinorityLabelBasedSampler

class MnistFedPerClient(MoonClient):
Note: We inherit from a MOON client here intentionally to be able to use auxiliary losses associated with the global module's feature space in addition to the personalized architecture.
@@ -26,7 +26,7 @@ def __init__(
        device: torch.device,
        minority_numbers: Set[int],
    ) -> None:
-        super().__init__(data_path=data_path, metrics=metrics, device=device)
+        super().__init__(data_path=data_path, metrics=metrics, device=device, perfcl_loss_weights=(1.0, 1.0))
Adding in the perfcl loss to the FENDA example. This is just for testing when running the example to make sure nothing is broken there.
@@ -207,7 +205,7 @@ def compute_loss(
        """
        loss = self.criterion(preds["prediction"], target)
-        total_loss = loss
+        total_loss = loss.clone()
Without `clone`, `total_loss` and `loss` share memory. This means that anything added in place to `total_loss` below is also added to `loss`, so the checkpoint and backward losses in the loss object end up being identical, which we don't want. I added a unit test to make sure the clone here fixes the issue.
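A minimal PyTorch sketch of the aliasing behavior described above (the scalar tensors here are stand-ins for the real loss tensors):

```python
import torch

# Without clone: plain assignment aliases the tensor, so an in-place
# add mutates the storage shared by both names.
loss = torch.tensor(1.0)
total_loss = loss
total_loss += 2.0          # in-place add_ hits the shared storage
print(loss.item())         # 3.0 -- loss was silently changed too

# With clone: total_loss gets its own storage and loss is untouched.
loss = torch.tensor(1.0)
total_loss = loss.clone()
total_loss += 2.0
print(loss.item())         # 1.0
```

This is exactly why the diff above swaps `total_loss = loss` for `total_loss = loss.clone()`.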
_, old_model_features = old_model(input)
old_features[i] = old_model_features["features"]
_, global_model_features = self.global_model(input)
features.update({"global_features": global_model_features["features"], "old_features": old_features})
With the refactor of the use of a features dictionary, this predict function ended up being broken. I believe these changes fix the issue. However, correct me if I'm wrong.
That is true; I fixed it in the contrastive losses PR and it seems correct on the version of main that I have. However, it might change in further merges. Thanks for pointing it out.
@@ -89,8 +89,6 @@ def get_contrastive_loss(
        return self.ce_criterion(logits, labels)

-    def set_parameters(self, parameters: NDArrays, config: Config) -> None:
-        assert isinstance(self.model, MoonModel)
This is unnecessary, unless I'm mistaken. Removing it also allows for FedPer models to be used with MOON Clients, which is nice...
Yes, it seems mypy doesn't have a problem with omitting it (that was originally its purpose). However, if we want to build FedPer on top of MOON, why don't we inherit from the MOON model? That would make it easier for users to understand that FedPer models can be used with MOON clients.
That's a fair question. The main reason I didn't do that is that FedPer models exchange a partial subset of weights and don't, at least by default, admit projection modules for their features. They are very related, though, so it's possible that unifying them is a good idea. What do you think?
`projection_module` is also optional in `moon_base`, so you can easily pass `None` for it in inheritance and everything should work well. I'd kinda prefer `fedper_base` to inherit from both `moon` and `partial_layer_exchange_model`, so users can see the relationship between all of them and the added functionality of FedPer built on both.
Yeah I think that makes sense. I'll do that right now.
I added the assert back in now that FedPer inherits from MOON.
-class ApflModule(nn.Module):
+class ApflModule(PartialLayerExchangeModel):
Having models that exchange a subset of their weights inherit from an abstract base class forces the implementation of the `layers_to_exchange` function. This is just a formalism for now, but should be useful in future iterations of the parameter exchanger mechanisms.
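A sketch of what such an abstract base might look like (hypothetical; the repo's actual `PartialLayerExchangeModel` may differ in its details):

```python
from abc import ABC, abstractmethod
from typing import List

import torch.nn as nn

class PartialLayerExchangeModel(nn.Module, ABC):
    """Base for models that exchange only a subset of their weights with the server."""

    @abstractmethod
    def layers_to_exchange(self) -> List[str]:
        """Return the state_dict keys that should be sent to the server."""
        raise NotImplementedError

class Incomplete(PartialLayerExchangeModel):
    pass

try:
    Incomplete()  # abstract method not implemented, so instantiation fails
except TypeError as err:
    print(f"cannot instantiate: {err}")
```

Subclasses like `ApflModule` are then forced by `ABCMeta` to implement `layers_to_exchange` before they can be instantiated.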
@@ -69,5 +71,7 @@ def update_alpha(self) -> None:
        self.alpha = alpha

    def layers_to_exchange(self) -> List[str]:
-        layers_to_exchange: List[str] = [layer for layer in self.state_dict().keys() if "global_model" in layer]
+        layers_to_exchange: List[str] = [
+            layer for layer in self.state_dict().keys() if layer.startswith("global_model.")
+        ]
Changing this function to mirror FENDA and FedPer. In particular, we only want layer names that start with `global_model.` rather than having the substring appear anywhere else in a name.
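The difference is easy to see with a couple of hypothetical state_dict keys (the key names below are made up for illustration):

```python
# A key can contain "global_model" without belonging to the global module.
keys = [
    "global_model.fc1.weight",
    "local_model.global_model_proj.weight",  # local layer whose name mentions global_model
]

substring_match = [k for k in keys if "global_model" in k]
prefix_match = [k for k in keys if k.startswith("global_model.")]

print(substring_match)  # picks up both keys, including the local one
print(prefix_match)     # only the true global-module layer
```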
super().__init__()
self.global_feature_extractor = global_feature_extractor
self.local_prediction_head = local_prediction_head
self.flatten_features = flatten_features
`flatten_features` is used to make the model compatible with MOON, which requires the intermediate feature representations to be flattened for similarity calculations.
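For example, a MOON-style similarity computation operates on one vector per example, so convolutional feature maps need flattening first (a sketch with made-up shapes, not the repo's actual loss code):

```python
import torch
import torch.nn.functional as F

# Conv-style feature maps: (batch, channels, height, width)
local_feats = torch.randn(4, 8, 5, 5)
global_feats = torch.randn(4, 8, 5, 5)

# Cosine similarity expects one vector per example,
# so flatten everything after the batch dimension.
flat_local = local_feats.flatten(start_dim=1)    # shape (4, 200)
flat_global = global_feats.flatten(start_dim=1)

similarity = F.cosine_similarity(flat_local, flat_global, dim=1)
print(similarity.shape)  # one similarity score per example in the batch
```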
Why don't you flatten by default? Is there a specific use case for the features in FedPer, compared with MOON?
The reason I don't flatten by default is that it changes the shape of the features tensor. If a user wanted to do something downstream with the features (other than the MOON calculations), I think they would be surprised if the features weren't in the expected shape.
model: nn.Module = FedPerModel(
    global_feature_extractor=FedPerGloalFeatureExtractor(),
    local_prediction_head=FedPerLocalPredictionHead(),
    flatten_features=True,
`flatten_features` is used to make the model compatible with MOON, which requires the intermediate feature representations to be flattened for similarity calculations.
Maybe you should write a comment regarding that, either here or in `fedper_base.py`.
Yes, definitely. Thanks for pointing that out.
fl4health/model_bases/moon_base.py (outdated)
-        # Return preds and features as seperate dictionairy as in fenda base
-        return {"prediction": preds}, {"features": features.view(len(features), -1)}
+        # Return preds and features as seperate dictionary as in fenda base
+        return {"prediction": preds}, {"features": features.reshape(len(features), -1)}
`reshape` is slightly more general than `view`.
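Concretely, `view` requires the tensor's memory to be contiguous, while `reshape` falls back to a copy when it isn't (a standalone PyTorch sketch of the distinction):

```python
import torch

x = torch.arange(6).reshape(2, 3)
t = x.t()  # transposing makes the tensor non-contiguous

try:
    t.view(6)  # view cannot handle non-contiguous memory
except RuntimeError as err:
    print(f"view failed: {err}")

flat = t.reshape(6)  # reshape silently copies when it has to
print(flat.tolist())
```

Since the model's features come out of arbitrary upstream ops, `reshape` is the safer choice here.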
@@ -0,0 +1,48 @@
# FedPer Federated Learning Example
This example demonstrates training a FedPer-type model on a non-IID subset of the MNIST data. The FL server
expects three clients to be spun up (i.e. it will wait until three clients report in before starting training). Each client
The number of clients is two in the config; we probably need to update the config or the readme.
Maybe I'm looking in the wrong place, but the config here, `examples/fedper_example/config.yaml`, has `n_clients: 3`. So I think the readme is accurate?
Oh, true. I somehow got confused with the feature alignment example and thought this was the readme added for it.
training and validation sets of MNIST in order to synthetically induce local variations in the statistical properties
of the clients' training/validation data. In theory, the models should be able to perform well on their local data
while learning from other clients' data that has different statistical properties. The subsampling is specified by
sending a list of integers between 0-9 to the clients when they are run with the argument `--minority_numbers`.
The client in this example does not have a `--minority_numbers` argument and does not use `MinorityLabelBasedSampler`. I think it assigns a random train-valid split.
The client here, `examples/fedper_example/client.py`, does take `--minority_numbers` and uses `MinorityLabelBasedSampler` in the `get_data_loaders` function. Please correct me if I'm wrong, though.
* `local_epochs`: number of epochs each client will train for locally
* `batch_size`: size of the batches each client will train on
* `n_server_rounds`: the number of rounds to run FL
* `downsampling_ratio`: the amount of downsampling to perform for minority digits
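Putting those parameters together, a `config.yaml` for this example might look roughly like the following (the values shown are illustrative, not the repo's actual settings):

```
n_clients: 3            # number of clients the server waits for
local_epochs: 2         # epochs of local training per round
batch_size: 32          # client batch size
n_server_rounds: 5      # number of FL rounds
downsampling_ratio: 0.1 # downsampling applied to minority digits
```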
Again, the config does not have a `downsampling_ratio` variable but has an extra `source_specified` variable.
So the config does have this argument. Based on the `source_specified` comment, I think you might be looking at the wrong config; that sounds like the config for the `feature_alignment_example`?
Once the server has started and logged "FL starting," the next step, in separate terminals, is to start the three
clients. This is done by simply running (remembering to activate your environment)
```
python -m examples.fedper_example.client --dataset_path /path/to/data --minority_numbers <sequence of numbers>
```
I believe that, because of the same issue, running this command would also raise an error.
See the comments above. I think we're okay here. I did run the example with these args and it worked appropriately.
The argument `minority_numbers` specifies which digits (0-9) in the MNIST dataset the client will subsample to
simulate non-IID data between clients. For example, `--minority_numbers 1 2 3 4 5` will ensure that the client
downsamples these digits (using the `downsampling_ratio` specified in the config).
This section also should be omitted.
Given the above, I think we're good to leave this in?
It all seems good to me now. Thanks for the PR.
Everything looks good to me as well! No other comments, thanks for adding the tests!
PR Type
Feature
Short Description
Clickup Ticket: https://app.clickup.com/t/8686ckn31
This PR adds the FedPer method to the repository. The addition is fairly straightforward within our infrastructure: it's essentially just a globally trained feature extractor with a locally trained classification head in each client. I added an example using the infrastructure while also training it with MOON. So the example client inherits from the MOON client to apply the auxiliary losses to the model along with local personalization, which I thought was nice.
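A rough sketch of the FedPer idea, for readers unfamiliar with it. This is not the repo's actual `FedPerModel`; the class name, layer sizes, and module names below are made up for illustration:

```python
from typing import List

import torch
import torch.nn as nn

class FedPerSketch(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Exchanged with the server and aggregated across clients.
        self.global_feature_extractor = nn.Sequential(nn.Linear(784, 64), nn.ReLU())
        # Kept on the client to personalize predictions.
        self.local_prediction_head = nn.Linear(64, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.local_prediction_head(self.global_feature_extractor(x))

    def layers_to_exchange(self) -> List[str]:
        # Only the globally shared weights ever leave the client.
        return [k for k in self.state_dict() if k.startswith("global_feature_extractor.")]

model = FedPerSketch()
print(model.layers_to_exchange())
print(model(torch.randn(2, 784)).shape)  # (2, 10) logits
```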
In the course of adding the method, I noticed a few small bugs in MOON and the auxiliary losses (like PerFCL) in the FENDA approaches. These have been fixed and tests have been added to ensure that they are indeed fixed.
Tests Added
Added tests covering MOON's contrastive loss calculations and the loss computations of MOON, FENDA, FedProx, and FedPer.