Skip to content

Commit

Permalink
[Model] Support quantization of PixtralHFTransformer for PixtralHF (v…
Browse files Browse the repository at this point in the history
…llm-project#9921)

Signed-off-by: mgoin <[email protected]>
  • Loading branch information
mgoin authored Nov 5, 2024
1 parent 731aec5 commit a53046b
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 40 deletions.
30 changes: 30 additions & 0 deletions vllm/model_executor/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,33 @@ def get_act_fn(
return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
params_dtype)
return act_fn


_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
"gelu": lambda: GeluAndMul(),
"silu": lambda: SiluAndMul(),
})


def get_act_and_mul_fn(
act_fn_name: str,
quant_config: Optional[QuantizationConfig] = None,
intermediate_size: Optional[int] = None,
input_is_parallel: bool = True,
params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
"""Get an activation-and-mul (i.e. SiluAndMul) function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
raise ValueError(
f"Activation function {act_fn_name!r} is not supported.")

act_fn = _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
if (quant_config is not None
and act_fn_name in quant_config.get_scaled_act_names()):
if intermediate_size is None:
raise ValueError("intermediate_size must be specified for scaled "
"activation functions.")
return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
params_dtype)
return act_fn
100 changes: 60 additions & 40 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Expand Down Expand Up @@ -798,20 +801,24 @@ def __init__(
super().__init__()

assert config.intermediate_size is not None
# TODO: Use quant_config and prefix after optimizing this
self.gate_proj = nn.Linear(config.hidden_size,
config.intermediate_size,
bias=False)
self.up_proj = nn.Linear(config.hidden_size,
config.intermediate_size,
bias=False)
self.down_proj = nn.Linear(config.intermediate_size,
config.hidden_size,
bias=False)
self.act = get_act_fn(config.hidden_act)
self.gate_up_proj = MergedColumnParallelLinear(
input_size=config.hidden_size,
output_sizes=[config.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(input_size=config.intermediate_size,
output_size=config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
self.act_and_mul = get_act_and_mul_fn(config.hidden_act)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
gate_up, _ = self.gate_up_proj(x)
x = self.act_and_mul(gate_up)
x, _ = self.down_proj(x)
return x


class PixtralHFAttention(nn.Module):
Expand All @@ -830,21 +837,21 @@ def __init__(
self.n_heads = config.num_attention_heads
self.head_dim = config.hidden_size // config.num_attention_heads

self.scale = self.head_dim**-0.5

# TODO: Use quant_config and prefix after optimizing this
self.q_proj = nn.Linear(config.hidden_size,
config.hidden_size,
bias=False)
self.k_proj = nn.Linear(config.hidden_size,
config.hidden_size,
bias=False)
self.v_proj = nn.Linear(config.hidden_size,
config.hidden_size,
bias=False)
self.o_proj = nn.Linear(config.hidden_size,
config.hidden_size,
bias=False)
self.qkv_proj = QKVParallelLinear(
hidden_size=config.hidden_size,
head_size=self.head_dim,
total_num_heads=self.n_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=config.hidden_size,
output_size=config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)

def forward(
self,
Expand All @@ -854,36 +861,35 @@ def forward(
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
batch, patches, _ = hidden_states.size()

q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
qkv_states, _ = self.qkv_proj(hidden_states)
q, k, v = qkv_states.chunk(3, dim=-1)

# Transpose q and k to apply HF's Rotary Position Embedding
q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
v = v.view(batch, patches, self.n_heads, self.head_dim)
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)

if USE_XFORMERS_OPS:
# Transpose q and k back for attention
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.reshape(batch, patches, self.n_heads, self.head_dim)

out = xops.memory_efficient_attention(q,
k,
v,
attn_bias=attention_mask)
else:
v = v.reshape(batch, patches, self.n_heads,
self.head_dim).transpose(1, 2)
v = v.transpose(1, 2)
out = nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attention_mask)
out = out.transpose(1, 2)

out = out.reshape(batch, patches, self.n_heads * self.head_dim)
out = out.view(batch, patches, self.n_heads * self.head_dim)
attn_output, _ = self.o_proj(out)

return self.o_proj(out)
return attn_output, None


class PixtralHFTransformerBlock(nn.Module):
Expand Down Expand Up @@ -912,9 +918,9 @@ def forward(
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
r = self.attention.forward(self.attention_norm(hidden_states),
attention_mask=attention_mask,
position_embeddings=position_embeddings)
r, _ = self.attention.forward(self.attention_norm(hidden_states),
attention_mask=attention_mask,
position_embeddings=position_embeddings)
h = hidden_states + r
r = self.feed_forward.forward(self.ffn_norm(h))
out = h + r
Expand Down Expand Up @@ -1053,10 +1059,24 @@ def forward(
# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = []
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
layer_count = len(self.transformer.layers)

for name, loaded_weight in weights:
# omit layers when num_hidden_layers_override is set
if name.startswith("transformer.layers"):
layer_idx = int(name.split(".")[2])
if layer_idx >= layer_count:
continue

for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
Expand Down

0 comments on commit a53046b

Please sign in to comment.