
[Feature][Hardware][AMD] Enable level 3 compilation on rocm #10836

Open
charlifu wants to merge 5 commits into main from enable_amd_torch_compile

Conversation

@charlifu (Contributor) commented Dec 2, 2024

This PR fixes the fusion pass not being enabled on ROCm by:

  • adding fp8 dtype selection to the fusion pass, since ROCm uses torch.float8_e4m3fnuz
  • using a tensor slice operation to replace the torch.narrow op, which creates extra ops in the IR generated by torch.compile and prevents the rms+fp8_quant fusion ([torch.compile] Fuse RMSNorm with quant #9138) from working; see the sketch below
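A minimal sketch of the dtype selection in the first bullet, using a hypothetical helper name (the real change lives in vLLM's fusion pass):

import torch

def fusion_fp8_dtype() -> torch.dtype:
    # Hypothetical helper: pick the fp8 dtype the fusion pass should match.
    # ROCm builds of PyTorch report torch.version.hip and use the fnuz fp8 variant.
    if torch.version.hip is not None:
        return torch.float8_e4m3fnuz
    return torch.float8_e4m3fn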


github-actions bot commented Dec 2, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@charlifu force-pushed the enable_amd_torch_compile branch from 33e7055 to f441c65 on December 2, 2024 20:59
@mgoin (Member) left a comment


Could you consider adding AMD to this fusion test case?

@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                    reason="Only test on CUDA")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
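One possible way to relax that skip condition so the test also runs on ROCm, shown as a sketch (it assumes the same pytest and envs imports as the existing test) rather than the final test code:

@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ("cuda", "rocm"),
                    reason="Only test on CUDA and ROCm")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
    ...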

@charlifu (Contributor, Author) commented Dec 3, 2024

@ProExpertProg @mgoin Some updates:

I was trying to enable the unit test of the fusion pass on ROCm. I found that with num_tokens < 17, we are still seeing the extra ops even with the slice operations, which fails the test. I think this is because we pad the input when num_tokens < 17, and this adds extra slice_scatter and slice ops to the generated IR.
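A rough, hypothetical sketch of the padding being described (not the actual vLLM code), showing where the extra slice_scatter comes from:

import torch

def pad_to_min_tokens(x: torch.Tensor, min_tokens: int = 17) -> torch.Tensor:
    # Pad the token dimension up to a minimum size (assumed threshold of 17).
    num_tokens = x.shape[0]
    if num_tokens >= min_tokens:
        return x
    padded = x.new_zeros(min_tokens, *x.shape[1:])
    padded[:num_tokens] = x  # under torch.compile this shows up as slice_scatter
    return padded            # callers later slice the result back to num_tokens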

May I ask if it is OK to disable the padding by default? Link

@ProExpertProg (Contributor)

I am looking into this and found the same issue.

There's also another problem, unrelated to fusion. When we compile with dynamic shapes, the max expression is not part of the trace, and the graph contains only the branch of the max that was taken. A dimension that's marked dynamic still has an underlying value, and that value is erroneously used in the max to pick the larger operand, even though it shouldn't be known at trace time.

That means if we compile with a dynamic num_tokens that's larger than 17, the graph will always use s0 (the dynamic dimension in place of num_tokens) for the size of the tensor, even if the actual value of s0 is less than 17. The other case is worse: if we originally compile with s0 < 17, the size will always be 17, even if s0 > 17 during execution.
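A minimal sketch of the kind of pattern being discussed (illustrative names, not the vLLM code):

import torch

def padded_rows(x: torch.Tensor) -> torch.Tensor:
    # The concern above: with a dynamic first dim, the compiled graph may
    # specialize on whichever side of this max the example input selected.
    n = max(x.shape[0], 17)
    return x.new_zeros(n, x.shape[1])

compiled = torch.compile(padded_rows, dynamic=True)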

@mgoin I'd advocate for removing the padding in the short term - how often do we deal with num_tokens < 17? In the long term, we can fix the tracing and the fusion for the padded case

@ProExpertProg (Contributor)

Minor correction: the max does get traced properly into a torch.sym_max(s0, 17), but for some reason it isn't used, so it gets optimized out during the autograd phase.

@mgoin (Member) commented Dec 3, 2024

Could we simply remove the padding in the non-CUDA case? It is only there because of poor scaled_mm performance on CUDA.
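A sketch of what gating the padding on the platform could look like, reusing the hypothetical pad_to_min_tokens helper from the sketch above:

import torch

def maybe_pad(x: torch.Tensor, min_tokens: int = 17) -> torch.Tensor:
    # Pad only on CUDA, where small-batch scaled_mm performance is the concern;
    # skip padding on ROCm so the rms+fp8_quant fusion pattern stays intact.
    if torch.version.hip is not None:
        return x
    return pad_to_min_tokens(x, min_tokens)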

@ProExpertProg (Contributor)

OK, if we want to keep the padding, I think I have a solution for the fusion (but not for dynamic-shape compilation). I can implement it tomorrow.

@ProExpertProg (Contributor)

For what it's worth, torch.narrow gets lowered to a slice operation anyway, so at least in the torch.compile regime there should be no difference from slicing.
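A quick check of that equivalence:

import torch

x = torch.randn(32, 8)
a = torch.narrow(x, 0, 0, 16)  # keep the first 16 rows along dim 0
b = x[:16]                     # plain slice of the same rows
assert torch.equal(a, b)       # both are views over the same data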
