Commit 660d5a2: replace with get_rope
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed Jan 3, 2025
1 parent a1de811 commit 660d5a2
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions vllm/model_executor/models/bamba.py
@@ -20,7 +20,7 @@
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     MambaMixer2, extra_groups_for_head_shards)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -161,10 +161,10 @@ def __init__(
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
-        self.rotary_emb = RotaryEmbedding(
+        self.rotary_emb = get_rope(
             head_size=self.head_dim,
             rotary_dim=config.attn_rotary_emb,
-            max_position_embeddings=max_position_embeddings,
+            max_position=max_position_embeddings,
             base=rope_theta,
             is_neox_style=True,
             dtype=torch.get_default_dtype(),  # see impl of get_rope
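For context, get_rope is vLLM's factory helper in vllm.model_executor.layers.rotary_embedding that builds the appropriate RotaryEmbedding variant from the given arguments, so the returned module is used downstream exactly as before. Below is a minimal standalone sketch of the call pattern, assuming the standard vLLM rotary-embedding interface; the concrete values are illustrative and not taken from bamba.py:

    import torch
    from vllm.model_executor.layers.rotary_embedding import get_rope

    # Hypothetical example values; in bamba.py these come from the model config.
    head_dim = 64
    rotary_emb = get_rope(
        head_size=head_dim,
        rotary_dim=head_dim,              # rotate the full head dimension
        max_position=4096,                # note: get_rope takes "max_position",
                                          # not "max_position_embeddings"
        base=10000,
        is_neox_style=True,
        dtype=torch.get_default_dtype(),  # see impl of get_rope
    )

    # In an attention forward pass, the returned module is applied to the
    # query/key tensors at the given positions, e.g.:
    #     q, k = rotary_emb(positions, q, k)

The main difference from constructing RotaryEmbedding directly is the renamed max_position keyword, which is why one argument is changed alongside the constructor call in this diff.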
