Commit

Add support for RoPE in MPT
Signed-off-by: Kazuki OIKAWA <[email protected]>
kazuki committed Nov 19, 2024
1 parent 8d6dcc7 commit ca8d3e8
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions vllm/model_executor/models/mpt.py
@@ -18,6 +18,7 @@
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@@ -65,7 +66,7 @@ def __init__(
else:
self.total_num_kv_heads = self.total_num_heads
assert not config.attn_config["prefix_lm"]
assert config.attn_config["alibi"]
assert config.attn_config["alibi"] or config.attn_config["rope"]

# pylint: disable=invalid-name
self.Wqkv = QKVParallelLinear(
@@ -102,13 +103,24 @@ def __init__(
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
# Create the alibi slopes and slice them.
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads,
self.alibi_bias_max)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()

alibi_slopes = None
self.rotary_emb = None
if config.attn_config["alibi"]:
# Create the alibi slopes and slice them.
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads,
self.alibi_bias_max)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
elif config.attn_config["rope"]:
self.rotary_emb = get_rope(
config.d_model // config.n_heads,
rotary_dim=config.d_model // config.n_heads,
max_position=config.max_seq_len,
base=config.attn_config["rope_theta"],
)

self.head_dim = self.d_model // self.total_num_heads
scaling = self.head_dim**-0.5
@@ -127,14 +139,15 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
del position_ids # unused.
qkv, _ = self.Wqkv(hidden_states)
if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.qk_ln:
q = self.q_ln(q)
k = self.k_ln(k)
if self.rotary_emb:
q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.out_proj(attn_output)
return output
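
Note: the new self.rotary_emb(position_ids, q, k) call rotates query/key channel pairs by position-dependent angles before attention is computed. The standalone PyTorch sketch below illustrates that transform only; it is not vLLM's get_rope implementation, and the helper name apply_rope and the usage lines are hypothetical.

import torch

def apply_rope(positions: torch.Tensor, x: torch.Tensor,
               base: float = 10000.0) -> torch.Tensor:
    # positions: [num_tokens]; x: [num_tokens, num_heads, head_dim]
    head_dim = x.shape[-1]
    # One inverse frequency per channel pair, as in standard (NeoX-style) RoPE.
    inv_freq = 1.0 / (base**(
        torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    angles = positions.float()[:, None] * inv_freq[None, :]
    cos = angles.cos().to(x.dtype)[:, None, :]  # broadcast over the heads dim
    sin = angles.sin().to(x.dtype)[:, None, :]
    x1, x2 = x[..., :head_dim // 2], x[..., head_dim // 2:]
    # Rotate each (x1, x2) pair by its position-dependent angle.
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

# Hypothetical usage mirroring the forward() change above, with q and k
# reshaped to [num_tokens, num_heads, head_dim] and base set to rope_theta:
# q = apply_rope(position_ids, q, base=rope_theta)
# k = apply_rope(position_ids, k, base=rope_theta)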
