Fuse batch normalization into convolution kernel #2629

Draft
wants to merge 1 commit into main
Conversation

mvpant
Contributor

@mvpant mvpant commented Nov 18, 2024

This introduces a simplification that folds the batch normalization inference operation into the convolution weights (kernel). The key idea is that while the batch normalization parameters change during training, they remain constant during inference, so the convolution kernel can be adjusted ahead of time to incorporate the effects of batch normalization. This optimization is applied by default to the ResNet model in the ONNX framework.
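
Roughly, for a convolution followed by batch_norm_inference with per-output-channel scale, offset, mean, variance, and epsilon, the two ops can be replaced by a single convolution with rescaled weights plus a broadcasted bias. A minimal NumPy sketch of the idea (names, layout, and shapes here are illustrative only, not the actual pass code; the real pattern has to respect the convolution's dimension_numbers):

```python
import numpy as np

def fold_bn_into_conv_weights(w, gamma, beta, mean, var, eps=1e-5):
    """Fold batch_norm_inference parameters into convolution weights.

    w:                      conv kernel, assumed shape [O, I, kH, kW] (illustrative layout)
    gamma, beta, mean, var: per-output-channel BN parameters, shape [O]

    Returns (w_fused, b_fused) such that
        conv(x, w_fused) + b_fused == batch_norm_inference(conv(x, w))
    """
    scale = gamma / np.sqrt(var + eps)        # per-channel multiplier
    w_fused = w * scale[:, None, None, None]  # rescale each output channel of the kernel
    b_fused = beta - mean * scale             # bias absorbing mean and offset
    return w_fused, b_fused
```
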

I will provide more details about the test case and additional formulas later.

For now, I would like to know whether there is interest in this.

@mvpant
Contributor Author

mvpant commented Nov 18, 2024

Regarding terminology, what is preferred in StableHLO for convolution rhs: kernel or weight?

@GleasonK
Member

The key idea is that while the batch normalization parameters change during training, they remain constant during inference, so the convolution kernel can be adjusted ahead of time to incorporate the effects of batch normalization.

Is this to say - during training these values won't be constant ops, and this pattern won't apply, but during inference it will? This seems reasonable. Overall certainly interested in growing the set of patterns available in the StableHLO repo.

We've discussed before that we'll need a way to adjust the knobs in terms of what patterns get applied, and that's a problem I plan to take on early next year. In the meantime, it's probably fine to have this pattern in this pass. If we decide it isn't desirable on the default path, we can always make it its own pass.

Regarding terminology, what is preferred in StableHLO for convolution rhs: kernel or weight?

cc @ghpvnist regarding the terminology question, any preference from a spec perspective?

@ghpvnist
Member

I like kernel but both are equally well understood imo, so up to the code author :) Since this isn't affecting the spec, anything works!

@mvpant
Contributor Author

mvpant commented Nov 19, 2024

Is this to say - during training these values won't be constant ops, and this pattern won't apply, but during inference it will? This seems reasonable.

Yes, I assume that’s why there are several operations: stablehlo.batch_norm_grad, stablehlo.batch_norm_inference, and stablehlo.batch_norm_training. stablehlo.batch_norm_inference is designed to be used during the inference phase, normalizing input data using the statistics computed during training.
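
For reference, the inference op applies a per-channel affine transform to its operand using those precomputed statistics. A minimal NumPy rendering of that semantics (the epsilon value and the assumption that the feature dimension is axis 1 are illustrative; in StableHLO they come from the op's attributes):

```python
import numpy as np

def batch_norm_inference_ref(x, scale, offset, mean, var, eps=1e-5):
    """Normalize x per feature channel with precomputed statistics, then
    apply scale and offset. Assumes the feature dimension is axis 1."""
    shape = [1, -1] + [1] * (x.ndim - 2)  # broadcast per-channel params over the other dims
    return (x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + eps) \
        * scale.reshape(shape) + offset.reshape(shape)
```
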
