
Allow feature_shape to be passed to fx plan initialize #983

Merged · 4 commits · Jun 24, 2024

Conversation

@psfoley (Contributor) commented Jun 5, 2024

This PR:

  • Fixes the feature_shape argument that can be passed to fx plan initialize. By passing this argument (which accepts list-type values, such as [1,28,28]), loading the data loader for the model owner / aggregator can be avoided. This is important because the aggregator may not have access to local data to generate the initial model.
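As a sketch (the flag spelling is taken from the examples later in this thread; the exact interface may differ), the aggregator-side invocation might look like:

```shell
# Hypothetical invocation: initialize the plan without a local dataset by
# passing the input shape directly; the literal is quoted so the shell
# delivers it to the CLI as a single argument.
fx plan initialize --input_shape "[1,28,28]"
```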

@MasterSkepticista (Collaborator) left a comment:

While this method works for single-input models with static shapes, it may not work for multi-tensor input models, or for single-input models with custom input dtypes.

A more general approach could be for the aggregator to take an InputSpec as an argument, where InputSpec is a dictionary or list of TensorSpec[TensorShape, DType].

Example multi-input-output model:

# Can generalize if `ast` parses it as a dict instead
--input_shape {'input_0': [1, 240, 240, 4], 'output_1': [1, 240, 240, 1]}

# Assumes first input to take the given shape, if `ast` parses as list.
--input_shape [1, 240, 240, 1]
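The two parsing behaviors described above could be sketched as follows. This is a hypothetical illustration (the function name `parse_input_shape` and the default key `"input_0"` are assumptions, not OpenFL's actual code); it only shows how `ast.literal_eval` distinguishes the dict and list forms:

```python
import ast

def parse_input_shape(raw: str) -> dict:
    """Hypothetical sketch of parsing --input_shape via ast.literal_eval.

    A dict maps input names to per-tensor shapes (multi-input models);
    a bare list is assumed to be the shape of the first (only) input.
    """
    value = ast.literal_eval(raw)
    if isinstance(value, dict):
        # e.g. "{'input_0': [1, 240, 240, 4], 'output_1': [1, 240, 240, 1]}"
        return value
    if isinstance(value, list):
        # e.g. "[1, 240, 240, 1]" -> treated as the first input's shape
        return {"input_0": value}
    raise ValueError(f"Expected a dict or list literal, got {type(value).__name__}")
```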

Review threads (resolved) on openfl/interface/plan.py and openfl/utilities/click_types.py.
psfoley added 2 commits June 6, 2024 21:54
Signed-off-by: Patrick Foley <[email protected]>
@psfoley (Contributor, Author) commented Jun 6, 2024

> While this method works for single-input models with static shapes, it may not work for multi-tensor input models, or for single-input models with custom input dtypes.
>
> A more general approach could be for the aggregator to take an InputSpec as an argument, where InputSpec is a dictionary or list of TensorSpec[TensorShape, DType].
>
> Example multi-input-output model:
>
> # Can generalize if `ast` parses it as a dict instead
> --input_shape {'input_0': [1, 240, 240, 4], 'output_1': [1, 240, 240, 1]}
>
> # Assumes first input to take the given shape, if `ast` parses as list.
> --input_shape [1, 240, 240, 1]

This is a good point. I tested that the existing code (which uses ast.literal_eval) can handle dictionaries as well. The main caveat is that the dictionary must be wrapped in quotes to be passed via the command line. I've tried to make this requirement clear in the help string, and provided your example (wrapped in double quotes) as a reference for users.
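The quoting caveat can be demonstrated with the standard library's `shlex`, which mimics POSIX shell word splitting (a sketch for illustration, not part of the PR):

```python
import ast
import shlex

# Why the quotes matter: the shell splits an unquoted dict literal on its
# internal spaces, so the CLI never receives it as a single argument.
unquoted = "fx plan initialize --input_shape {'input_0': [1, 240, 240, 4]}"
quoted = "fx plan initialize --input_shape \"{'input_0': [1, 240, 240, 4]}\""

print(shlex.split(unquoted))    # dict literal broken across several tokens
print(shlex.split(quoted)[-1])  # one token: the complete dict literal

# Once quoted, the argument round-trips cleanly through ast.literal_eval.
shape = ast.literal_eval(shlex.split(quoted)[-1])
print(shape["input_0"])
```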

@kta-intel (Collaborator) commented Jun 7, 2024

This looks good to me, overall. It successfully avoids the need to load data at the aggregator to initialize the model, which is a great addition.

One comment: this won't always guarantee that the model is initialized with the specified input values, or that the dataloader will load data with the specified shape at the collaborators. For example, the torch_cnn_mnist workspace hard-codes the model parameters in src/taskrunner.py, so it will ultimately ignore the values passed through fx plan initialize --input_shape and initialize regardless; and even if we use .get_feature_shape() during initialization, only the channel size matters, since HxW isn't required for torch conv layers.

In my opinion, this is not a breaking issue and this PR is good to go. Going forward, however, we should consider ways to reduce the potential user confusion this can cause, such as having the existing workspace templates error out if an invalid input_shape is provided, or issue a warning when it is ignored.
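The suggested warning could be as simple as the following sketch (the helper `check_input_shape` is hypothetical and not part of this PR; it only illustrates the idea of flagging an ignored --input_shape):

```python
import warnings

def check_input_shape(requested, actual):
    """Hypothetical helper: warn when a user-supplied --input_shape is
    ignored because the workspace hard-codes its model parameters."""
    if requested is not None and list(requested) != list(actual):
        warnings.warn(
            f"--input_shape {requested} differs from the model's hard-coded "
            f"input shape {actual}; the provided value will be ignored."
        )

# e.g. a workspace fixing its shape regardless of the flag:
check_input_shape([1, 32, 32], [1, 28, 28])  # emits a UserWarning
```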

@teoparvanov (Collaborator) left a comment:

Looks great, thanks @psfoley !

"""Mock objects to eliminate extraneous dependencies"""


class MockDataLoader:
A Collaborator commented:
I suppose you don't extend DataLoader, because only the get_feature_shape() method is required during plan initialization? In this case, would it make sense to call this class FeatureShapeLoader?
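The truncated snippet above could be completed along these lines (a hypothetical sketch; the actual class in the PR may differ):

```python
class MockDataLoader:
    """Hypothetical completion of the mock above: plan initialization
    only needs get_feature_shape(), so no real DataLoader is required."""

    def __init__(self, feature_shape):
        self.feature_shape = feature_shape

    def get_feature_shape(self):
        # The only hook `fx plan initialize` needs from a data loader here.
        return self.feature_shape
```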

@psfoley psfoley merged commit eeafaf5 into securefederatedai:develop Jun 24, 2024
25 of 26 checks passed
manuelhsantana pushed a commit that referenced this pull request Jul 10, 2024
* Remove need for aggregator to have dataset for model weight initialization

Signed-off-by: Patrick Foley <[email protected]>

* Remove extra print arguments

Signed-off-by: Patrick Foley <[email protected]>

* Address review comments

Signed-off-by: Patrick Foley <[email protected]>

* Remove extraneous print statement

Signed-off-by: Patrick Foley <[email protected]>

---------

Signed-off-by: Patrick Foley <[email protected]>
4 participants