[OPT] improve rms_norm kernel #258
@kkHuang-amd Could you please add a brief description of what this PR does, what improvement is expected, and why?
LGTM
Conditionally approving. My points of concern:
- int32 for indices. Are we sure in this use case it won't cause overflows? cc @rasmith
- Possible performance implications on platforms other than MI (Navi). cc @maleksan85
Some places where integer overflow could occur.
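To make the int32 concern concrete, here is a minimal host-side sketch, not code from this PR. It assumes a row-major `[num_tokens, hidden_size]` tensor; `flat_index` and `overflows_int32` are hypothetical helper names used only for illustration.

```cpp
#include <cstdint>
#include <limits>

// Hypothetical helper: flat offset of element (row, col) in a row-major
// [num_tokens, hidden_size] tensor, computed in 64-bit arithmetic so the
// multiplication itself cannot wrap.
inline int64_t flat_index(int32_t row, int32_t hidden_size, int32_t col) {
    return static_cast<int64_t>(row) * hidden_size + col;
}

// Hypothetical helper: true when that offset does not fit in an int32 index.
inline bool overflows_int32(int32_t row, int32_t hidden_size, int32_t col) {
    return flat_index(row, hidden_size, col) >
           std::numeric_limits<int32_t>::max();
}
```

With `hidden_size = 8192`, row 300000 already gives an offset of 2,457,600,000, which exceeds the int32 maximum of 2,147,483,647. Whether shapes that large can occur in this use case is the open question.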
@rasmith Is this good now?
Looks good!
The base branch was changed.
Approved, pending conflict resolution
…BFloat16ELi8EEENSt9enable_ifIXooooeqT0_Li0Entsr12_typeConvertIT_EE6existseqLi2ELi2EEvE4typeEPS4_PKS4_S9_fiii"
LGTM
@kkHuang-amd This PR introduces a correctness regression on the Llama-3.2-90B-Vision-Instruct-FP8-KV model without Triton flash attention.
The issue is the following:
Use vectorized memory access to improve rms_norm performance.
On the ROCm platform, a single instruction can access 16 bytes of memory. This PR uses that instruction to vectorize memory access in the rms_norm kernel.
In testing, the rms_norm kernel shows improved performance.
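As a rough illustration of the idea, the sketch below runs RMS norm over one row on the host, moving data in 16-byte packs. It is not the kernel from this PR. `Vec4` and `rms_norm_row` are hypothetical names, the element type is assumed to be `float` (four elements per 16 bytes), and `hidden_size` is assumed to be a multiple of 4. On the GPU the same pattern would presumably let each thread issue one 16-byte load per pack instead of several narrower ones.

```cpp
#include <cmath>
#include <cstring>

// Hypothetical 16-byte pack of four floats. Copying a whole pack at once
// gives the compiler the chance to use a single 128-bit load or store.
struct alignas(16) Vec4 {
    float v[4];
};

// RMS norm over one row: out[i] = in[i] / sqrt(mean(in^2) + epsilon) * weight[i].
// Assumes hidden_size % 4 == 0 (an assumption of this sketch).
void rms_norm_row(float* out, const float* in, const float* weight,
                  int hidden_size, float epsilon) {
    const int num_packs = hidden_size / 4;

    // Pass 1: accumulate the sum of squares, one 16-byte pack at a time.
    float sum_sq = 0.0f;
    for (int i = 0; i < num_packs; ++i) {
        Vec4 x;
        std::memcpy(&x, in + 4 * i, sizeof(Vec4));
        for (int j = 0; j < 4; ++j) sum_sq += x.v[j] * x.v[j];
    }
    const float inv_rms = 1.0f / std::sqrt(sum_sq / hidden_size + epsilon);

    // Pass 2: scale and apply the weight, again one pack per load and store.
    for (int i = 0; i < num_packs; ++i) {
        Vec4 x, w, y;
        std::memcpy(&x, in + 4 * i, sizeof(Vec4));
        std::memcpy(&w, weight + 4 * i, sizeof(Vec4));
        for (int j = 0; j < 4; ++j) y.v[j] = x.v[j] * inv_rms * w.v[j];
        std::memcpy(out + 4 * i, &y, sizeof(Vec4));
    }
}
```

A row whose length is not a multiple of the pack width would need a scalar tail loop or a fallback to a non-vectorized path, which this sketch leaves out.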