Unfused codepath for non-supported quant_types
ElizaWszola committed Aug 26, 2024
1 parent fd4bb21 commit 7956a69
Showing 1 changed file with 191 additions and 64 deletions.
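
The change below keeps the fused FusedMoE kernel for supported quantization types and falls back to a plain per-expert loop otherwise. A rough, self-contained sketch of that unfused routing follows (illustrative only: dummy tensors and plain nn.Linear experts stand in for the committed MixtralMLP / ReplicatedLinear layers, and tensor parallelism is omitted):

import torch
import torch.nn.functional as F

# Toy sizes; the real model gets these from MixtralConfig.
num_tokens, hidden_size, num_experts, top_k = 4, 16, 8, 2
hidden_states = torch.randn(num_tokens, hidden_size)
router_logits = torch.randn(num_tokens, num_experts)  # output of the gate

# Same routing math as the unfused MixtralMoE.forward in the diff.
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

# Stand-in experts; the commit uses MixtralMLP (ReplicatedLinear w1/w2/w3).
experts = [torch.nn.Linear(hidden_size, hidden_size, bias=False)
           for _ in range(num_experts)]

final_hidden_states = torch.zeros_like(hidden_states)
for expert_idx, expert_layer in enumerate(experts):
    expert_mask = (selected_experts == expert_idx)          # (num_tokens, top_k)
    expert_weights = (routing_weights * expert_mask).sum(
        dim=-1, keepdim=True)                               # 0 if not routed here
    final_hidden_states += expert_layer(hidden_states) * expert_weights
# In the model, this sum is then all-reduced across tensor-parallel ranks.
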
vllm/model_executor/models/mixtral.py (255 changes: 191 additions & 64 deletions)
@@ -25,12 +25,15 @@
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from transformers import MixtralConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -46,6 +49,7 @@
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.scalar_type import scalar_types
from vllm.sequence import IntermediateTensors, SamplerOutput

from .interfaces import SupportsLoRA
@@ -54,6 +58,45 @@
logger = logging.getLogger(__name__)


class MixtralMLP(nn.Module):

def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.num_experts = num_experts
self.ffn_dim = intermediate_size
self.hidden_dim = hidden_size

self.w1 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
quant_config=quant_config)
self.w2 = ReplicatedLinear(self.ffn_dim,
self.hidden_dim,
bias=False,
quant_config=quant_config)
self.w3 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
quant_config=quant_config)

# TODO: Use vllm's SiluAndMul
self.act_fn = nn.SiLU()
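        # Together, w1 (gate), w3 (up) and w2 (down) form a SwiGLU-style MLP:
        # forward computes SiLU(w1(x)) * w3(x), then projects back with w2.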

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
w1_out, _ = self.w1(hidden_states)
w1_out = self.act_fn(w1_out)
w3_out, _ = self.w3(hidden_states)
current_hidden_states = w1_out * w3_out
current_hidden_states, _ = self.w2(current_hidden_states)
return current_hidden_states


class MixtralMoE(nn.Module):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
across all ranks.
@@ -69,6 +112,7 @@ def __init__(
top_k: int,
hidden_size: int,
intermediate_size: int,
use_fused_moe: bool,
params_dtype: Optional[torch.dtype] = torch.float16,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
@@ -87,28 +131,68 @@ def __init__(
prefix=f"{prefix}.gate",
)

self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=tp_size,
prefix=f"{prefix}.experts",
)
self.use_fused_moe = use_fused_moe
if self.use_fused_moe:
self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=tp_size,
prefix=f"{prefix}.experts",
)
else:
self.top_k = top_k
self.num_experts = num_experts
self.experts = nn.ModuleList([
MixtralMLP(num_experts,
hidden_size,
intermediate_size,
quant_config=quant_config)
for idx in range(num_experts)
])

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape, orig_type = hidden_states.shape, hidden_states.dtype
hidden_states = hidden_states.view(-1, self.hidden_size).to(
self.params_dtype)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape).to(orig_type)
if self.use_fused_moe:
hidden_states = hidden_states.view(-1, self.hidden_size).to(
self.params_dtype)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape).to(orig_type)
else:
hidden_states = hidden_states.view(-1, self.hidden_size)
router_logits, _ = self.gate(hidden_states.half())
routing_weights = F.softmax(router_logits,
dim=1,
dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights,
self.top_k,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
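            # Each row of routing_weights now sums to 1 over its top_k experts;
            # experts outside the top_k contribute weight 0 in the loop below.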

final_hidden_states = None
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
expert_mask = (selected_experts == expert_idx)
expert_weights = (routing_weights * expert_mask).sum(
dim=-1, keepdim=True)

current_hidden_states = expert_layer(hidden_states).mul_(
expert_weights)
if final_hidden_states is None:
final_hidden_states = current_hidden_states
else:
final_hidden_states.add_(current_hidden_states)

return tensor_model_parallel_all_reduce(final_hidden_states).view(
orig_shape).to(orig_type)


class MixtralAttention(nn.Module):
@@ -197,6 +281,7 @@ class MixtralDecoderLayer(nn.Module):

def __init__(
self,
use_fused_moe: bool,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
@@ -221,6 +306,7 @@ def __init__(
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
use_fused_moe=use_fused_moe,
quant_config=quant_config,
tp_size=get_tensor_model_parallel_world_size(),
params_dtype=torch.float16,
@@ -264,6 +350,7 @@ class MixtralModel(nn.Module):

def __init__(
self,
use_fused_moe: bool,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
@@ -286,7 +373,11 @@ def __init__(
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MixtralDecoderLayer(
config, cache_config, quant_config=quant_config, prefix=prefix
use_fused_moe,
config,
cache_config,
quant_config=quant_config,
prefix=prefix,
),
prefix=f"{prefix}.layers",
)
@@ -358,10 +449,13 @@ def __init__(
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()

        # TODO: keep the mixtral_quant-style unfused codepath around as long as
        # not all quant_types are supported by the fused path
self.use_fused_moe = (quant_config.quant_type == scalar_types.uint4b8)
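        # Only uint4b8 weights use the fused MoE kernel; all other quant_types
        # fall back to the unfused per-expert MixtralMLP path.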
self.config = config
self.lora_config = lora_config
self.model = MixtralModel(config,
self.model = MixtralModel(self.use_fused_moe,
config,
cache_config,
quant_config,
lora_config=lora_config,
@@ -436,65 +530,98 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
("qkv_proj", "v_proj", "v"),
]

# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_local_experts,
)
if self.use_fused_moe:

params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_local_experts,
)

for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue

param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
# Skip layers on other devices.
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
is_quantized=True,
)
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
# Skip layers on other devices.
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
is_quantized=True,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
else:

params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name,
shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:

if ("block_sparse_moe.experts." in name
and name not in params_dict):
continue

param = params_dict[name]
