[OPT] improve rms_norm kernel #258

Merged
merged 13 commits into from
Nov 20, 2024

Conversation

@kkHuang-amd kkHuang-amd commented Nov 1, 2024

Use vectorized memory access to improve rms_norm performance.

On the ROCm platform, a single instruction can perform a 16-byte memory access. This PR uses that instruction to vectorize memory access in the rms_norm kernel.

In testing, the rms_norm kernel shows a performance improvement.
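To make the idea above concrete, here is a minimal host-side C++ sketch, not the PR's actual kernel. It assumes `float` data moved in packs of four (16 bytes). The mangled symbol later in this conversation suggests the real kernel works on 16-bit types in packs of eight, so the element type and width here are stand-ins. The names `Pack4`, `rms_norm_scalar`, and `rms_norm_packed` are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstring>

// Hypothetical 16-byte pack of four floats (illustrative name, not from the PR).
struct alignas(16) Pack4 {
  float v[4];
};

// Scalar reference: out[i] = x[i] / sqrt(mean(x^2) + eps) * w[i].
inline void rms_norm_scalar(const float* x, const float* w, float* out,
                            std::size_t n, float eps) {
  assert(n > 0);
  float sum_sq = 0.0f;
  for (std::size_t i = 0; i < n; ++i) sum_sq += x[i] * x[i];
  const float scale = 1.0f / std::sqrt(sum_sq / static_cast<float>(n) + eps);
  for (std::size_t i = 0; i < n; ++i) out[i] = x[i] * scale * w[i];
}

// Same computation, but data moves in 16-byte packs. Requires n % 4 == 0.
inline void rms_norm_packed(const float* x, const float* w, float* out,
                            std::size_t n, float eps) {
  assert(n > 0 && n % 4 == 0);
  float sum_sq = 0.0f;
  for (std::size_t i = 0; i < n; i += 4) {
    Pack4 px;
    // Copy 16 bytes as one unit; a compiler may lower this to one wide load.
    std::memcpy(&px, x + i, sizeof(Pack4));
    for (int k = 0; k < 4; ++k) sum_sq += px.v[k] * px.v[k];
  }
  const float scale = 1.0f / std::sqrt(sum_sq / static_cast<float>(n) + eps);
  for (std::size_t i = 0; i < n; i += 4) {
    Pack4 px, pw, po;
    std::memcpy(&px, x + i, sizeof(Pack4));
    std::memcpy(&pw, w + i, sizeof(Pack4));
    for (int k = 0; k < 4; ++k) po.v[k] = px.v[k] * scale * pw.v[k];
    std::memcpy(out + i, &po, sizeof(Pack4));  // one 16-byte store
  }
}
```

The packed version only changes how data is moved, not the arithmetic, so both functions should agree up to floating-point rounding. The packed path is valid only when the row length is a multiple of the pack width.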

@gshtras
Collaborator

gshtras commented Nov 1, 2024

@kkHuang-amd Could you please fill in a brief description of what this is doing, which improvement is expected, and why?

HaiShaw
HaiShaw previously approved these changes Nov 1, 2024
@HaiShaw HaiShaw left a comment

LGTM

gshtras
gshtras previously approved these changes Nov 4, 2024
Collaborator

@gshtras gshtras left a comment

Conditionally approving. My points of concern:

  1. int32 for indices. Are we sure in this use case it won't cause overflows? cc @rasmith
  2. Possible performance implications on platforms other than MI (Navi). cc @maleksan85
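The first concern can be illustrated with a small host-side C++ sketch. It assumes a row-major `[rows, hidden_size]` tensor whose flat offset is `row * hidden_size + col`. That layout and the helper names `flat_offset` and `fits_in_int32` are assumptions for illustration, not code from the PR. The offset is computed in 64 bits and then checked against the `int32` range, which avoids relying on signed overflow (undefined behavior in C++).

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Flat offset of element (row, col) in a row-major [rows, hidden_size] tensor,
// computed in 64 bits so the product cannot wrap for realistic shapes.
inline std::int64_t flat_offset(std::int64_t row, std::int64_t hidden_size,
                                std::int64_t col) {
  assert(row >= 0 && hidden_size > 0 && col >= 0 && col < hidden_size);
  return row * hidden_size + col;
}

// True when the value is representable as a 32-bit signed integer.
inline bool fits_in_int32(std::int64_t v) {
  return v >= std::numeric_limits<std::int32_t>::min() &&
         v <= std::numeric_limits<std::int32_t>::max();
}
```

With `hidden_size = 8192`, the last element of row 262143 has offset 2147483647, which is exactly the `int32` maximum. The first element of row 262144 has offset 2^31, which no longer fits. Whether such shapes occur in practice depends on the workload.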

@rasmith rasmith left a comment

Some places where integer overflow could occur.

csrc/layernorm_kernels.cu: six review threads, all outdated and resolved.
@kkHuang-amd kkHuang-amd dismissed stale reviews from gshtras and HaiShaw via 22669a6 November 6, 2024 07:46
@kkHuang-amd kkHuang-amd requested a review from rasmith November 6, 2024 09:37
gshtras
gshtras previously approved these changes Nov 7, 2024
@HaiShaw
HaiShaw commented Nov 11, 2024

@rasmith Is this good now?

rasmith
rasmith previously approved these changes Nov 11, 2024
@rasmith
rasmith commented Nov 11, 2024

@rasmith Is this good now?

Looks good!

@gshtras gshtras changed the base branch from main to develop November 12, 2024 16:47
@gshtras gshtras dismissed stale reviews from rasmith and themself November 12, 2024 16:47

The base branch was changed.

gshtras
gshtras previously approved these changes Nov 12, 2024
Collaborator

@gshtras gshtras left a comment

Approved, pending conflict resolution

@kkHuang-amd kkHuang-amd requested a review from gshtras November 18, 2024 08:53
…BFloat16ELi8EEENSt9enable_ifIXooooeqT0_Li0Entsr12_typeConvertIT_EE6existseqLi2ELi2EEvE4typeEPS4_PKS4_S9_fiii"
@HaiShaw HaiShaw left a comment

LGTM

@gshtras gshtras merged commit 15c78e7 into develop Nov 20, 2024
6 of 7 checks passed
@gshtras gshtras deleted the kk/rms_norm_opt branch November 20, 2024 15:40
@gshtras
Collaborator

gshtras commented Nov 26, 2024

@kkHuang-amd This PR introduces a correctness regression on the Llama-3.2-90B-Vision-Instruct-FP8-KV model when run without Triton flash attention.
I'm planning to revert it pending further investigation.

@gshtras
Collaborator

gshtras commented Nov 26, 2024

The issue is the following:
There used to be two versions of the kernel, one for MI and one for Navi.
After this PR, the Navi version is called on MI whenever vec_size is not a multiple of 8, which seems to be the case in the vision model. This causes a correctness regression.
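The dispatch described above can be sketched in host-side C++ as follows. This is an assumed reconstruction for illustration, not the PR's code. The names `KernelPath`, `is_aligned`, and `choose_path`, the pack width of 8 elements, and the 16-byte alignment check are all hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Which implementation a launch would use (illustrative names only).
enum class KernelPath { Packed, ScalarFallback };

// True when the pointer address is a multiple of `bytes`.
inline bool is_aligned(const void* p, std::size_t bytes) {
  return reinterpret_cast<std::uintptr_t>(p) % bytes == 0;
}

// Pick the packed path only when every precondition for 16-byte access holds.
// Assumes 16-bit elements, so 8 elements make one 16-byte pack.
inline KernelPath choose_path(std::size_t hidden_size, const void* input,
                              const void* weight) {
  constexpr std::size_t kWidth = 8;       // elements per pack (assumed)
  constexpr std::size_t kPackBytes = 16;  // bytes per pack
  const bool ok = hidden_size > 0 && hidden_size % kWidth == 0 &&
                  is_aligned(input, kPackBytes) &&
                  is_aligned(weight, kPackBytes);
  return ok ? KernelPath::Packed : KernelPath::ScalarFallback;
}
```

With a dispatch of this shape, any size that is not a multiple of 8 silently takes the fallback. The fallback therefore has to be correct on every platform that can reach it, and tests should cover sizes on both sides of the condition.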

gshtras added a commit that referenced this pull request Nov 26, 2024