rfc: graph: propose to support Grouped-query Attention #2018
Conversation
@gyhintel, thanks for the RFC. Some questions:
Yes, it means that the pattern cannot be used to optimize a framework graph directly. Users will have to map their GQA implementation graph to our pattern. This is the second con of option 2.
In the current PyTorch implementation, no extra action is needed on their side. But if the implementation in the community changes, we will still need to handle the new implementation. This is the second con of option 1.
1. The pattern is less intuitive from the GQA definition.
2. The pattern cannot be used to optimize a framework graph directly. Frameworks
   will have to implement GQA fusion by themselves and leverage this option to
   optimize the fused GQA.
If this turns out to be a serious con, it would be reasonable to add a pass to match the Option 1 subgraph and convert it to the Option 2 subgraph, right?
- If it is a serious con, we need to implement Option 1, adding new ops and new patterns. The pass that matches the Option 1 subgraph and converts it to the Option 2 subgraph would be a backend implementation; we could also implement it in other ways in the backend.
- If the pass can be done on the framework side, we only need to implement Option 2.
We will have to support and match the subgraph in Option 1 once the request pops up. With that, oneDNN will have to support and maintain several different patterns for the same GQA functionality. Maybe that is not an issue: even if we choose Option 1 as the initial step, the pattern may still change in the future, as mentioned in the cons of Option 1.
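To illustrate why such a conversion pass is feasible at all, here is a hedged NumPy sketch. It assumes (the RFC details are not quoted in this thread) that the Option 1 subgraph contains an explicit repeat of the KV heads before a plain SDPA, while the Option 2 subgraph uses a 5D reshape plus size-1 broadcasting; under those assumptions the two formulations are numerically equivalent, so one subgraph can be rewritten into the other. All shapes are illustrative.

```python
import numpy as np

batch, kv_heads, group, seq, head_dim = 2, 2, 4, 5, 8
rng = np.random.default_rng(0)
q = rng.standard_normal((batch, kv_heads * group, seq, head_dim))
k = rng.standard_normal((batch, kv_heads, seq, head_dim))
v = rng.standard_normal((batch, kv_heads, seq, head_dim))

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def sdpa(q, k, v):
    # Plain scaled dot-product attention; batch dims broadcast via np.matmul.
    return softmax(q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])) @ v

# "Option 1"-style graph: explicitly repeat each KV head for its query group,
# then run ordinary 4D SDPA.
out1 = sdpa(q, np.repeat(k, group, axis=1), np.repeat(v, group, axis=1))

# "Option 2"-style graph: reshape Q to 5D and give K/V a size-1 group dim,
# letting standard broadcasting share each KV head across its group.
out2 = sdpa(q.reshape(batch, kv_heads, group, seq, head_dim),
            k[:, :, None], v[:, :, None])
out2 = out2.reshape(batch, kv_heads * group, seq, head_dim)

assert np.allclose(out1, out2)
```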
(see broadcasting in
[ONNX](https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md) and
[NumPy](https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules)),
but actually it's added to the MatMul operation of cuDNN in order to support
I'd like to add the link here for reference: https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-graph-library.html#cudnn-backend-operation-matmul-descriptor .
Added, thanks.
that the new broadcasting rule is only supported by the fused attention.
2. Same as option 2, the pattern still cannot be used to optimize a framework
   graph directly. Frameworks will have to implement GQA fusion by themselves
   and leverage this option to optimize the fused GQA.
Another con here may be that we rely on oneDNN matmul primitive kernels for the reference implementation and for testing in benchdnn, and those kernels do not support the new broadcasting rule. Extending the broadcast semantics on the graph side will also require additional effort for the reference implementation and testing.
Added, thanks.
## GQA in PyTorch

Unlike SDPA, PyTorch does not support GQA as a fused operation. In Huggingface
FYI - the PyTorch PR just got merged this week: pytorch/pytorch#132689
Any clarity at this point whether users are fine with implementing Option 2 on their side, or Option 1 must be implemented instead?
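For context, a rough NumPy sketch of the semantics that PR enables (the exact PyTorch API surface and flag name are not quoted here, so treat the details as assumptions): scaled dot-product attention can accept K/V with fewer heads than Q when GQA is enabled, instead of requiring the caller to repeat the KV heads up front.

```python
import numpy as np

q = np.zeros((1, 8, 4, 16))   # (batch, num_q_heads, seq, head_dim)
k = np.zeros((1, 2, 4, 16))   # only 2 KV heads, shared by 8 query heads
group = q.shape[1] // k.shape[1]

# Caller-side workaround without fused GQA support: repeat each KV head
# for its query group so shapes match plain SDPA expectations.
k_full = np.repeat(k, group, axis=1)
assert k_full.shape == q.shape
```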
@chunyuan-w, @sanchitintel, could you help take a look at this RFC? Thanks!
| Matrix A | Matrix B | Matrix C = A x B |
| -- | -- | -- |
| B1 x 1 x B3 x M x K | B1 x B2 x 1 x M x K | B1 x B2 x B3 x M x N |
Suggested change: the shape of Matrix B should be B1 x B2 x 1 x K x N, not B1 x B2 x 1 x M x K:
| B1 x 1 x B3 x M x K | B1 x B2 x 1 x K x N | B1 x B2 x B3 x M x N |
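A quick check of the corrected row using NumPy, which applies its general broadcasting rules to the batch dimensions of `matmul`. The sizes below are arbitrary placeholders for B1, B2, B3, M, K, N.

```python
import numpy as np

B1, B2, B3, M, K, N = 2, 3, 4, 5, 6, 7
a = np.zeros((B1, 1, B3, M, K))   # Matrix A: size-1 dim at axis 1
b = np.zeros((B1, B2, 1, K, N))   # Matrix B: size-1 dim at axis 2
c = a @ b                         # batch dims broadcast to (B1, B2, B3)
assert c.shape == (B1, B2, B3, M, N)
```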
Fixed, thanks.
Description
This RFC proposes supporting Grouped-query Attention in the oneDNN Graph API.
Link to the rendered document.