Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Bamba Model #10909

Open
wants to merge 34 commits into
base: main
Choose a base branch
from
Open

Add Bamba Model #10909

wants to merge 34 commits into from

Conversation

fabianlim
Copy link

@fabianlim fabianlim commented Dec 5, 2024

This is the companion PR to an huggingface PR for adding Bamba, which is a hybrid mamba2 architecture with SwiGLU. The checkpoints are jointly trained by IBM, Princeton, and UIUC.

In this PR we have:

  • Created the bamba model inference architecture, which we would like acknowledge the jamba team for referencing their implementation, whereby we modified to support full attention layers with RoPE and mamba v2.
  • Ensured that we have TP support.
  • Ensured we support chunked prefill. Currently we have a partial solution, which works only when the cont batch boundaries line up with the chunked boundaries. This is now completely fixed.
  • Ensured that we conform to the recent PR for adding pipeline support for SSM models.
  • Adapted the mamba v2 scan kernels into vllm/model_executor/layers/mamba/ops. Only the fwd kernels are extracted. Some modifications and fixes are made.
  • created tests/models/decoder_only/language/test_bamba.py with an initial ibm-fms/Bamba-9.8b-1.8T-hf. This is practically identical to test_mamba.py, only chunked prefill tests are disabled as it is currently not supported.

Currently only FlashAttention backend is supported, as we check fields like context_lens_tensor. Have not yet investigated other backends.

We would like to also acknowledge the draft codestral mamba PR from @tlrmchlsmth, which we also referenced the mixer.

  • we made a few simplications for bamba (simplified mixer from mamba v2)
  • Cuda graph capturing seems to be working, but we understand that cudagraphs are disabled for long sequence lengths. For SSM models the strength is in this regime, so can we handle it better?

Hope to discuss the following with the maintainers

  1. do we have to remove all the bwd kernels? yes we should
  2. for the full attention layers, we increase the sin_cos cache to cover the sequence length, if it is longer than max_sequence_len. This differs for other current models (e.g., llama). How can we better support long sequence lengths? we should keep this consistent with other models, so we propose to allow the sin_cos cache extension only when VLLM_ALLOW_LONG_MAX_MODEL_LEN is specified.
  3. have some ideas to support chunked pre-fill, but will appreciate some discussion with the maintainers on how to proceed. working on changing the kernels to support chunked prefill.
  4. since the mixer2 is simplified from mamba, should we rename it? we can keep it as is, but we should document the differences from mamba_ssm

cc: @ani300, @raghukiran1224, @cyang49, @njhill

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@fabianlim fabianlim marked this pull request as draft December 5, 2024 01:35
Copy link

github-actions bot commented Dec 5, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@tlrmchlsmth
Copy link
Collaborator

Hi @fabianlim, thanks for the PR! It's really great to see progress being made on state-space models, especially for me as I unfortunately haven't been able to prioritize support for Mamba2

I'm happy to shepherd this PR and discuss any questions you have, especially to support chunked prefill. If you haven't already, can you join the developer slack for quicker discussion? (https://communityinviter.com/apps/vllm-dev/join-vllm-developers-slack)

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@fabianlim
Copy link
Author

fabianlim commented Dec 12, 2024

@tlrmchlsmth I cleaned up the PR quite abit, perhaps it might be a good time to get some early eyes. The chunked prefill implementation is incomplete ATM, as we discussed offline.

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

first pass, just a few comments. At a high level it looks good.

Will you add a test for tensor parallelism?

Comment on lines 9 to 10
# will be ch
MODELS = ["ibm-fms/Bamba-9.8b-1.8T-hf"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment trails off, but will there be a small test model available?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@raghukiran1224 any plans for a small test model? I think since we do outputs comparison it is not that good to just have a randomly initialised small model

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fabianlim @tlrmchlsmth would it be ok to test with a random model or would you rather have a tiny model (say 200M or so) to test with?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A tiny model with nonrandom weights would be much better!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw is there any update on this?

tests/models/decoder_only/language/test_bamba.py Outdated Show resolved Hide resolved
vllm/model_executor/layers/mamba/ops/ssd_bmm.py Outdated Show resolved Hide resolved
Copy link

mergify bot commented Dec 13, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fabianlim.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 13, 2024
@fabianlim
Copy link
Author

@tlrmchlsmth i have addressed most of your comments now, not rebasing yet, waiting for you to look first. But I realized test_jamba.py has changed so I will need to do the rename and test again.

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fabianlim At a high level, the changes look good, and the PR looks good overall. I'll do a more thorough review once it's unmarked as draft.

Could you add unit tests for the added kernels in layers/mamba/ops?

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@mergify mergify bot removed the needs-rebase label Dec 16, 2024
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@fabianlim
Copy link
Author

fabianlim commented Dec 23, 2024

@tlrmchlsmth I have now marked the PR as ready and have addressed the remaining items, in particular the chunked prefill is the biggest change. This requires changes into the kernels. To test this, I have added unit tests for various chunked prefill scenarios.

However in test_mamba_prefill_chunking I removed two prompts due to difficulties getting the decoding results to exactly match; the unit tests are quite extensive though and the match was quite statisfactory there. The original tests for jamba also removed some prompts for the same reason.

In the unit tests, I have things like [(8, 8), (8, 8), (8, 8)], this tests the chunk prefill as follows

  • the first (cont batch) is 8 tokens of example 1, followed by 8 tokens of example 2, after the decoding the state is saved.
  • the second (cont batch) is 8 tokens of example 1, followed by 8 tokens of example 2. To ensure correct decoding, the saved state must be passed into the kernels (the updates to the kernels are to ingest this state data appropriately).
  • and so on ...

There are other tests that vary the chunk size, the number of tokens being passed in each batch, etc..

Know is the holiday season and pls take your time. Happy holidays!

@fabianlim fabianlim marked this pull request as ready for review December 23, 2024 11:47
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fabianlim do you know when a version of transformers that supports bamba will be released? This PR will likely be blocked until it's out

tests/kernels/test_mamba_ssm_ssd.py Outdated Show resolved Hide resolved
tests/kernels/test_mamba_ssm_ssd.py Outdated Show resolved Hide resolved
tests/kernels/test_mamba_ssm_ssd.py Outdated Show resolved Hide resolved
tests/kernels/test_mamba_ssm_ssd.py Outdated Show resolved Hide resolved
tests/kernels/test_mamba_ssm_ssd.py Outdated Show resolved Hide resolved
vllm/model_executor/layers/mamba/mamba_mixer2.py Outdated Show resolved Hide resolved
vllm/model_executor/layers/mamba/mamba_mixer2.py Outdated Show resolved Hide resolved
Comment on lines 9 to 10
# will be ch
MODELS = ["ibm-fms/Bamba-9.8b-1.8T-hf"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw is there any update on this?

vllm/model_executor/models/bamba.py Outdated Show resolved Hide resolved
Comment on lines 212 to 233
# because the bamba model may potentially handle long sequences,
# we should adjust the sin_cos cache if necessary to avoid out of bounds
# - first get the max_position
max_position = max(
getattr(attn_metadata, 'max_prefill_seq_len', 0),
getattr(attn_metadata, 'max_decode_seq_len', 0),
)
if max_position == 0:
# if we cannot get the max length from the metadata, then
# get it from the positions
max_position = positions.max().item()

# when VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 could potentially cause inputs
# longer than max_position_embeddings. We extend the rope cache
# to prevent CUDA errors. Be aware that the outputs could be of
# lower quality for long sequence lengths.
rotary = self.rotary_emb
if rotary.max_position_embeddings <= max_position:
# we set it to the next power of two that covers it
while rotary.max_position_embeddings <= max_position:
rotary.max_position_embeddings *= 2
rotary.cos_sin_cache = rotary._compute_cos_sin_cache()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of this, have you considered using rope_scaling instead? If you use get_rope instead of constructing the RotaryEmbedding directly, I think it should work for bamba

See this unit test for an example of how it works:
https://github.com/sasha0552/vllm/blob/d427e5cfda8d2536b81e6021128e71b2dbc281aa/tests/test_config.py#L177

See also:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tlrmchlsmth ok I will follow your suggestion. this has been changed in 63f5340

@fabianlim
Copy link
Author

fabianlim commented Jan 2, 2025

hey @tlrmchlsmth thanks for your comments will get to them, the model has been merged to HF main, but it is still pending an official release to my knowledge. Hopefully that should happen soon since the last patch release was 2 weeks ago.

I addressed most of the comments in the following commits. I have replaced with get_rope, however I am a little confused with the dynamic implementation. I cant seem to understand how the cache is dynamic; it seems like DynamicNTKScalingRotaryEmbedding indeed has a different alpha, but the cache is statically computed for the "largest it can ever be" and on init. However the point is that it is still static, as "largest" is still determined on init. However, we did this to get a dynamic cache that adjusts whenever longer sequences come in (which may not be known on init).

fabianlim and others added 5 commits January 3, 2025 11:00
Co-authored-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@fabianlim
Copy link
Author

@tlrmchlsmth ok I have updated the PR again where I have addressed the last remaining comment on the rope behavior, the team is in the process of finalizing the dev checkpoint, when we upload it will ping you again.

@tlrmchlsmth
Copy link
Collaborator

@fabianlim Thanks, sounds good!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants