diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 3edc578e519bf..3cb1938953506 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -25,12 +25,15 @@
 from typing import Iterable, List, Optional, Tuple

 import torch
+import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (get_pp_group,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -46,6 +49,7 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.scalar_type import scalar_types
 from vllm.sequence import IntermediateTensors, SamplerOutput

 from .interfaces import SupportsLoRA
@@ -54,6 +58,45 @@
 logger = logging.getLogger(__name__)


+class MixtralMLP(nn.Module):
+
+    def __init__(
+        self,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.num_experts = num_experts
+        self.ffn_dim = intermediate_size
+        self.hidden_dim = hidden_size
+
+        self.w1 = ReplicatedLinear(self.hidden_dim,
+                                   self.ffn_dim,
+                                   bias=False,
+                                   quant_config=quant_config)
+        self.w2 = ReplicatedLinear(self.ffn_dim,
+                                   self.hidden_dim,
+                                   bias=False,
+                                   quant_config=quant_config)
+        self.w3 = ReplicatedLinear(self.hidden_dim,
+                                   self.ffn_dim,
+                                   bias=False,
+                                   quant_config=quant_config)
+
+        # TODO: Use vllm's SiluAndMul
+        self.act_fn = nn.SiLU()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        w1_out, _ = self.w1(hidden_states)
+        w1_out = self.act_fn(w1_out)
+        w3_out, _ = self.w3(hidden_states)
+        current_hidden_states = w1_out * w3_out
+        current_hidden_states, _ = self.w2(current_hidden_states)
+        return current_hidden_states
+
+
 class MixtralMoE(nn.Module):
     """A tensor-parallel MoE implementation for Mixtral that shards each
     expert across all ranks.
@@ -69,6 +112,7 @@ def __init__(
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        use_fused_moe: bool,
         params_dtype: Optional[torch.dtype] = torch.float16,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -87,28 +131,68 @@ def __init__(
             prefix=f"{prefix}.gate",
         )

-        self.experts = FusedMoE(
-            num_experts=num_experts,
-            top_k=top_k,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            params_dtype=params_dtype,
-            reduce_results=True,
-            renormalize=True,
-            quant_config=quant_config,
-            tp_size=tp_size,
-            prefix=f"{prefix}.experts",
-        )
+        self.use_fused_moe = use_fused_moe
+        if self.use_fused_moe:
+            self.experts = FusedMoE(
+                num_experts=num_experts,
+                top_k=top_k,
+                hidden_size=hidden_size,
+                intermediate_size=intermediate_size,
+                params_dtype=params_dtype,
+                reduce_results=True,
+                renormalize=True,
+                quant_config=quant_config,
+                tp_size=tp_size,
+                prefix=f"{prefix}.experts",
+            )
+        else:
+            self.top_k = top_k
+            self.num_experts = num_experts
+            self.experts = nn.ModuleList([
+                MixtralMLP(num_experts,
+                           hidden_size,
+                           intermediate_size,
+                           quant_config=quant_config)
+                for idx in range(num_experts)
+            ])

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
         orig_shape, orig_type = hidden_states.shape, hidden_states.dtype
-        hidden_states = hidden_states.view(-1, self.hidden_size).to(
-            self.params_dtype)
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states, router_logits)
-        return final_hidden_states.view(orig_shape).to(orig_type)
+        if self.use_fused_moe:
+            hidden_states = hidden_states.view(-1, self.hidden_size).to(
+                self.params_dtype)
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            final_hidden_states = self.experts(hidden_states, router_logits)
+            return final_hidden_states.view(orig_shape).to(orig_type)
+        else:
+            hidden_states = hidden_states.view(-1, self.hidden_size)
+            router_logits, _ = self.gate(hidden_states.half())
+            routing_weights = F.softmax(router_logits,
+                                        dim=1,
+                                        dtype=torch.float)
+            routing_weights, selected_experts = torch.topk(routing_weights,
+                                                           self.top_k,
+                                                           dim=-1)
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+            final_hidden_states = None
+            for expert_idx in range(self.num_experts):
+                expert_layer = self.experts[expert_idx]
+                expert_mask = (selected_experts == expert_idx)
+                expert_weights = (routing_weights * expert_mask).sum(
+                    dim=-1, keepdim=True)
+
+                current_hidden_states = expert_layer(hidden_states).mul_(
+                    expert_weights)
+                if final_hidden_states is None:
+                    final_hidden_states = current_hidden_states
+                else:
+                    final_hidden_states.add_(current_hidden_states)
+
+            return tensor_model_parallel_all_reduce(final_hidden_states).view(
+                orig_shape).to(orig_type)


 class MixtralAttention(nn.Module):
@@ -197,6 +281,7 @@ class MixtralDecoderLayer(nn.Module):

     def __init__(
         self,
+        use_fused_moe: bool,
         config: MixtralConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
@@ -221,6 +306,7 @@ def __init__(
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
+            use_fused_moe=use_fused_moe,
             quant_config=quant_config,
             tp_size=get_tensor_model_parallel_world_size(),
             params_dtype=torch.float16,
@@ -264,6 +350,7 @@ class MixtralModel(nn.Module):

     def __init__(
         self,
+        use_fused_moe: bool,
         config: MixtralConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
@@ -286,7 +373,11 @@ def __init__(
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: MixtralDecoderLayer(
-                config, cache_config, quant_config=quant_config, prefix=prefix
+                use_fused_moe,
+                config,
+                cache_config,
+                quant_config=quant_config,
+                prefix=prefix,
             ),
             prefix=f"{prefix}.layers",
         )
@@ -358,10 +449,13 @@ def __init__(
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
-
+        # TODO keep the fused mixtral_quant codepath around as long as we don't
+        # support all quant_types
+        self.use_fused_moe = (quant_config.quant_type == scalar_types.uint4b8)
         self.config = config
         self.lora_config = lora_config
-        self.model = MixtralModel(config,
+        self.model = MixtralModel(self.use_fused_moe,
+                                  config,
                                   cache_config,
                                   quant_config,
                                   lora_config=lora_config,
@@ -436,65 +530,98 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "v_proj", "v"),
         ]

-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="w1",
-            ckpt_down_proj_name="w2",
-            ckpt_up_proj_name="w3",
-            num_experts=self.config.num_local_experts,
-        )
+        if self.use_fused_moe:

-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
+            # Params for weights, fp8 weight scales, fp8 activation scales
+            # (param_name, weight_name, expert_id, shard_id)
+            expert_params_mapping = FusedMoE.make_expert_params_mapping(
+                ckpt_gate_proj_name="w1",
+                ckpt_down_proj_name="w2",
+                ckpt_up_proj_name="w3",
+                num_experts=self.config.num_local_experts,
+            )

-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip layers on other devices.
-                if is_pp_missing_parameter(name, self):
+            params_dict = dict(self.named_parameters())
+            for name, loaded_weight in weights:
+                if "rotary_emb.inv_freq" in name:
                     continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
+                for param_name, weight_name, shard_id in stacked_params_mapping:
                     if weight_name not in name:
                         continue
-                    # Skip layers on other devices.
                     name = name.replace(weight_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    # Skip layers on other devices.
                     if is_pp_missing_parameter(name, self):
                         continue
+                    param = params_dict[name]
                     weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                        is_quantized=True,
-                    )
+                    weight_loader(param, loaded_weight, shard_id)
                     break
                 else:
+                    for mapping in expert_params_mapping:
+                        param_name, weight_name, expert_id, shard_id = mapping
+                        if weight_name not in name:
+                            continue
+                        # Skip layers on other devices.
+                        name = name.replace(weight_name, param_name)
+                        if is_pp_missing_parameter(name, self):
+                            continue
+                        param = params_dict[name]
+                        weight_loader = param.weight_loader
+                        weight_loader(
+                            param,
+                            loaded_weight,
+                            name,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                            is_quantized=True,
+                        )
+                        break
+                    else:
+                        # Skip loading extra bias for GPTQ models.
+                        if name.endswith(".bias") and name not in params_dict:
+                            continue
+                        # Skip layers on other devices.
+                        if is_pp_missing_parameter(name, self):
+                            continue
+                        # Remapping the name of FP8 kv-scale.
+                        name = maybe_remap_kv_scale_name(name, params_dict)
+                        if name is None:
+                            continue
+
+                        param = params_dict[name]
+                        weight_loader = getattr(param, "weight_loader",
+                                                default_weight_loader)
+                        weight_loader(param, loaded_weight)
+        else:
+
+            params_dict = dict(self.named_parameters())
+            for name, loaded_weight in weights:
+                if "rotary_emb.inv_freq" in name:
+                    continue
+                for (param_name, weight_name,
+                     shard_id) in stacked_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
-                    # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
                         continue
-                    # Remapping the name of FP8 kv-scale.
-                    name = maybe_remap_kv_scale_name(name, params_dict)
-                    if name is None:
+
+                    if ("block_sparse_moe.experts." in name
+                            and name not in params_dict):
                         continue
                     param = params_dict[name]
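Note for reviewers: the standalone sketch below reproduces, outside vLLM, the unfused routing math that the new `MixtralMoE.forward` branch implements above (softmax over the router logits, top-k selection, renormalization, then a weighted sum of per-expert SwiGLU outputs). The tensor sizes, toy weights, and the inline SwiGLU stand-in for `MixtralMLP` are illustrative assumptions only, not vLLM code; in the diff the per-rank partial result is additionally combined across tensor-parallel ranks with `tensor_model_parallel_all_reduce`.

```python
# Minimal sketch of the unfused MoE routing path, assuming toy shapes/weights.
import torch
import torch.nn.functional as F

num_experts, top_k, hidden, inter = 8, 2, 16, 32
hidden_states = torch.randn(4, hidden)        # (num_tokens, hidden_dim)
router_logits = torch.randn(4, num_experts)   # (num_tokens, n_experts)

# Toy per-expert SwiGLU weights standing in for MixtralMLP (w1/w3 up, w2 down).
w1 = torch.randn(num_experts, inter, hidden)
w3 = torch.randn(num_experts, inter, hidden)
w2 = torch.randn(num_experts, hidden, inter)

# Softmax -> top-k -> renormalize, as in the else-branch of MixtralMoE.forward.
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

final_hidden_states = torch.zeros_like(hidden_states)
for expert_idx in range(num_experts):
    # Per-token weight for this expert; zero if the expert was not selected.
    expert_mask = (selected_experts == expert_idx)
    expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
    # SwiGLU expert: silu(x @ w1^T) * (x @ w3^T) projected down by w2.
    expert_out = (F.silu(hidden_states @ w1[expert_idx].T) *
                  (hidden_states @ w3[expert_idx].T)) @ w2[expert_idx].T
    final_hidden_states += expert_weights.to(expert_out.dtype) * expert_out

print(final_hidden_states.shape)  # torch.Size([4, 16])
```

As in the `nn.ModuleList` path above, every expert is run on all tokens and unselected tokens are masked out via zero routing weights, trading efficiency for simplicity relative to the fused `FusedMoE` kernel.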