-
Notifications
You must be signed in to change notification settings - Fork 215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow feature_shape
to be passed to fx plan initialize
#983
Allow feature_shape
to be passed to fx plan initialize
#983
Conversation
…ation Signed-off-by: Patrick Foley <[email protected]>
Signed-off-by: Patrick Foley <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While this method works for single-input models with static shapes, it may not work for multi-tensor input models, or single-input models with custom input dtype
s.
A general approach could be for the aggregator to take InputSpec
as an argument. InputSpec
is could be a dictionary of list of TensorSpec[TensorShape, DType]
.
Example multi-input-output model:
# Can generalize if `ast` parses it as a dict instead
--input_shape {'input_0': [1, 240, 240, 4], 'output_1': [1, 240, 240, 1]}
# Assumes first input to take the given shape, if `ast` parses as list.
--input_shape [1, 240, 240, 1]
Signed-off-by: Patrick Foley <[email protected]>
Signed-off-by: Patrick Foley <[email protected]>
This is a good point. I tested that the existing code (that makes use of |
This looks good to me, overall. It successfully avoids the need of loading the data at the aggregator to initialize the model, which is a great addition. One comment: this won't always guarantee that the model is initialized using the specified input values or that the dataloader will load in data with the specified shape at the collaborators. For example, the In my opinion, this is not a breaking issue and this PR is good to go. However, going forward we should consider ways to reduce any potential confusion caused by this from a user perspective, such as having the existing workspace templates erroring out if an invalid |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, thanks @psfoley !
"""Mock objects to eliminate extraneous dependencies""" | ||
|
||
|
||
class MockDataLoader: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suppose you don't extend DataLoader
, because only the get_feature_shape()
method is required during plan initialization? In this case, would it make sense to call this class FeatureShapeLoader
?
* Remove need for aggregator to have dataset for model weight initialization Signed-off-by: Patrick Foley <[email protected]> * Remove extra print arguments Signed-off-by: Patrick Foley <[email protected]> * Address review comments Signed-off-by: Patrick Foley <[email protected]> * Remove extraneous print statement Signed-off-by: Patrick Foley <[email protected]> --------- Signed-off-by: Patrick Foley <[email protected]>
This PR:
feature_shape
argument that can be passed tofx plan initialize
. By passing this argument (which allows for list types arguments, such as[1,28,28]
), loading the data loader for the model owner / aggregator can be avoided. This is important because the aggregator may not have access to local data to generate the initial model.