diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index 73d79dab01b4f..95fd962f4de76 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -30,38 +30,38 @@ template <typename scalar_t> struct __align__(16) vec8_t {
   scalar_t x, y, z, w, u, v, s, t;
 
-  __device__ vec8_t() : x(0), y(0), z(0), w(0), u(0), v(0), s(0), t(0) {}
-  __device__ vec8_t(scalar_t x, scalar_t y, scalar_t z, scalar_t w, scalar_t u, scalar_t v, scalar_t s, scalar_t t) : x(x), y(y), z(z), w(w), u(u), v(v), s(s), t(t){}
-
-  __device__ vec8_t operator*(const vec8_t& other) const {
-    return vec8_t(x * other.x, y * other.y, z * other.z, w * other.w,
-                  u * other.u, v * other.v, s * other.s, t * other.t);
-  }
+  __device__ vec8_t() : x(0), y(0), z(0), w(0), u(0), v(0), s(0), t(0) {}
+  __device__ vec8_t(scalar_t x, scalar_t y, scalar_t z, scalar_t w, scalar_t u,
+                    scalar_t v, scalar_t s, scalar_t t)
+      : x(x), y(y), z(z), w(w), u(u), v(v), s(s), t(t) {}
+
+  __device__ vec8_t operator*(const vec8_t& other) const {
+    return vec8_t(x * other.x, y * other.y, z * other.z, w * other.w,
+                  u * other.u, v * other.v, s * other.s, t * other.t);
+  }
 
-  __device__ vec8_t operator*(const float& scale) const {
-    return vec8_t(x * scale, y * scale, z * scale, w * scale,
-                  u * scale, v * scale, s * scale, t * scale);
-  }
+  __device__ vec8_t operator*(const float& scale) const {
+    return vec8_t(x * scale, y * scale, z * scale, w * scale, u * scale,
+                  v * scale, s * scale, t * scale);
+  }
 
-  __device__ vec8_t operator+(const vec8_t& other) const {
-    return vec8_t(x + other.x, y + other.y, z + other.z, w + other.w,
-                  u + other.u, v + other.v, s + other.s, t + other.t);
-  }
+  __device__ vec8_t operator+(const vec8_t& other) const {
+    return vec8_t(x + other.x, y + other.y, z + other.z, w + other.w,
+                  u + other.u, v + other.v, s + other.s, t + other.t);
+  }
 
-  __device__ void operator+=(const vec8_t& other) {
-    x += other.x;
-    y += other.y;
-    z += other.z;
-    w += other.w;
-    u += other.u;
-    v += other.v;
-    s += other.s;
-    t += other.t;
-  }
+  __device__ void operator+=(const vec8_t& other) {
+    x += other.x;
+    y += other.y;
+    z += other.z;
+    w += other.w;
+    u += other.u;
+    v += other.v;
+    s += other.s;
+    t += other.t;
+  }
 
-  __device__ scalar_t sum() const {
-    return x + y + z + w + u + v + s + t;
-  }
+  __device__ scalar_t sum() const { return x + y + z + w + u + v + s + t; }
 };
 
 // TODO(woosuk): Further optimize this kernel.
@@ -71,46 +71,49 @@ __global__ void rms_norm_kernel(
     const scalar_t* __restrict__ input,   // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
-  __shared__ float s_variance;
-
-  vec8_t<scalar_t> v8_variance = {0, 0, 0, 0, 0, 0, 0, 0};
-
-  vec8_t<scalar_t>* vectorized_out = reinterpret_cast<vec8_t<scalar_t>*>(out);
-  vec8_t<scalar_t> const* vectorized_in = reinterpret_cast<vec8_t<scalar_t> const*>(input);
-  vec8_t<scalar_t> const* vectorized_weight = reinterpret_cast<vec8_t<scalar_t> const*>(weight);
-  const int vec_hidden_size = hidden_size >> 3;
-
-  // Compute variance. Be carefull, hidden_size should multiple of 4.
-  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
-    vec8_t<scalar_t> x = vectorized_in[blockIdx.x * vec_hidden_size + idx];
-    v8_variance += x * x;
-  }
-  float v8_variance_sum = v8_variance.sum();
-
-  using BlockReduce = cub::BlockReduce<float, 1024>;
-  __shared__ typename BlockReduce::TempStorage reduceStore;
-  float variance = BlockReduce(reduceStore).Reduce(v8_variance_sum, cub::Sum{}, blockDim.x);
-
-  if (threadIdx.x == 0) {
-    s_variance = rsqrtf(variance / hidden_size + epsilon);
-  }
-  __syncthreads();
-
-  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
-    vec8_t<scalar_t> v8_in = vectorized_in[blockIdx.x * vec_hidden_size + idx];
-    vec8_t<scalar_t> v8_w = vectorized_weight[idx];
-    vectorized_out[blockIdx.x * vec_hidden_size + idx] = v8_in * s_variance * v8_w;
-  }
+  __shared__ float s_variance;
+
+  vec8_t<scalar_t> v8_variance = {0, 0, 0, 0, 0, 0, 0, 0};
+
+  vec8_t<scalar_t>* vectorized_out = reinterpret_cast<vec8_t<scalar_t>*>(out);
+  vec8_t<scalar_t> const* vectorized_in =
+      reinterpret_cast<vec8_t<scalar_t> const*>(input);
+  vec8_t<scalar_t> const* vectorized_weight =
+      reinterpret_cast<vec8_t<scalar_t> const*>(weight);
+  const int vec_hidden_size = hidden_size >> 3;
+
+  // Compute variance. Note: hidden_size must be a multiple of 8.
+  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
+    vec8_t<scalar_t> x = vectorized_in[blockIdx.x * vec_hidden_size + idx];
+    v8_variance += x * x;
+  }
+  float v8_variance_sum = v8_variance.sum();
+
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+  float variance =
+      BlockReduce(reduceStore).Reduce(v8_variance_sum, cub::Sum{}, blockDim.x);
+
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
+    vec8_t<scalar_t> v8_in = vectorized_in[blockIdx.x * vec_hidden_size + idx];
+    vec8_t<scalar_t> v8_w = vectorized_weight[idx];
+    vectorized_out[blockIdx.x * vec_hidden_size + idx] =
+        v8_in * s_variance * v8_w;
+  }
 }
 
 template <typename scalar_t>
 __global__ void scaled_rms_norm_kernel(
-    hip_fp8* __restrict__ out,  // [..., hidden_size]
+    hip_fp8* __restrict__ out,           // [..., hidden_size]
     const scalar_t* __restrict__ input,  // [..., hidden_size]
     const scalar_t* __restrict__ weight, // [hidden_size]
-    const float* scale,
-    const float epsilon, const int num_tokens, const int hidden_size,
-    const int hidden_size_padded) {
+    const float* scale, const float epsilon, const int num_tokens,
+    const int hidden_size, const int hidden_size_padded) {
   __shared__ float s_variance;
   float variance = 0.0f;
diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py
index c266d049bd54c..a2bc0a8c792a0 100644
--- a/vllm/model_executor/models/grok1.py
+++ b/vllm/model_executor/models/grok1.py
@@ -26,7 +26,6 @@
 import torch
 from torch import nn
 
-from vllm.transformers_utils.configs import Grok1Config
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
@@ -47,6 +46,7 @@
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs import Grok1Config
 
 from .interfaces import SupportsLoRA
 from .utils import is_pp_missing_parameter, make_layers
@@ -54,8 +54,8 @@
 attn_output_multiplier = 0.08838834764831845
 output_multiplier_scale = 0.5773502691896257
 max_attn_val = 30.0
-reduce_conversion_kernel: bool = True if os.getenv("VLLM_FP8_REDUCE_CONV",
-                                                   '0') == "1" else False
+reduce_conversion_kernel: bool = os.getenv("VLLM_FP8_REDUCE_CONV", '0') == "1"
+
 
 class Grok1MoE(nn.Module):
     """A tensor-parallel MoE implementation for Grok1 that shards each expert
@@ -201,27 +201,29 @@ def __init__(
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
-        self.attn = Grok1Attention(
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            max_position=config.max_position_embeddings,
-            num_kv_heads=config.num_key_value_heads,
-            rope_theta=rope_theta,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.attn")
-        self.moe_block = Grok1MoE(
-            num_experts=config.num_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-            quant_config=quant_config,
-            prefix=f"{prefix}.moe_block")
-
-        self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.attn = Grok1Attention(hidden_size=self.hidden_size,
+                                   num_heads=config.num_attention_heads,
+                                   max_position=config.max_position_embeddings,
+                                   num_kv_heads=config.num_key_value_heads,
+                                   rope_theta=rope_theta,
+                                   cache_config=cache_config,
+                                   quant_config=quant_config,
+                                   prefix=f"{prefix}.attn")
+        self.moe_block = Grok1MoE(num_experts=config.num_experts,
+                                  top_k=config.num_experts_per_tok,
+                                  hidden_size=config.hidden_size,
+                                  intermediate_size=config.intermediate_size,
+                                  quant_config=quant_config,
+                                  prefix=f"{prefix}.moe_block")
+
+        self.pre_attn_norm = RMSNorm(config.hidden_size,
+                                     eps=config.rms_norm_eps)
+        self.post_attn_norm = RMSNorm(config.hidden_size,
+                                      eps=config.rms_norm_eps)
+        self.pre_moe_norm = RMSNorm(config.hidden_size,
+                                    eps=config.rms_norm_eps)
+        self.post_moe_norm = RMSNorm(config.hidden_size,
+                                     eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -234,10 +236,15 @@ def forward(
         # Self Attention
         if residual is None:
             residual = hidden_states
-            hidden_states = self.pre_attn_norm(hidden_states, self.attn.qkv_proj.activation_scaling_factor) if reduce_conversion_kernel else self.pre_attn_norm(hidden_states)
+            hidden_states = self.pre_attn_norm(
+                hidden_states, self.attn.qkv_proj.activation_scaling_factor
+            ) if reduce_conversion_kernel else self.pre_attn_norm(
+                hidden_states)
         else:
-            hidden_states, residual = self.pre_attn_norm(hidden_states, self.attn.qkv_proj.activation_scaling_factor, residual) if reduce_conversion_kernel else self.pre_attn_norm(
-                hidden_states, residual)
+            hidden_states, residual = self.pre_attn_norm(
+                hidden_states, self.attn.qkv_proj.activation_scaling_factor,
+                residual) if reduce_conversion_kernel else self.pre_attn_norm(
+                    hidden_states, residual)
         hidden_states = self.attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -254,7 +261,7 @@ def forward(
         hidden_states = self.moe_block(hidden_states)
         hidden_states = self.post_moe_norm(hidden_states)
 
-        return hidden_states, residual 
+        return hidden_states, residual
 
 
 class Grok1Model(nn.Module):
@@ -357,10 +364,10 @@ def __init__(
         self.lora_config = lora_config
 
         self.model = Grok1Model(config,
-                        cache_config,
-                        quant_config,
-                        lora_config=lora_config,
-                        prefix="model")
+                                cache_config,
+                                quant_config,
+                                lora_config=lora_config,
+                                prefix="model")
         self.unpadded_vocab_size = config.vocab_size
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
@@ -377,7 +384,8 @@ def __init__(
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size, output_multiplier_scale)
+                                                config.vocab_size,
+                                                output_multiplier_scale)
         self.sampler = Sampler()
 
     def forward(
@@ -489,9 +497,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             if "norm.scale" in name:
                 name = name.replace("scale", "weight")
 
-            if "lm_head" in name:
-                if self.config.tie_word_embeddings:
-                    continue
+            if "lm_head" in name and self.config.tie_word_embeddings:
+                continue
 
             if name is None:
                 continue
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 1809ccee4a6ab..199c1caa2164e 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -18,12 +18,11 @@
 # yapf: disable
 from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                              EAGLEConfig, ExaoneConfig,
-                                             GraniteConfig, InternVLChatConfig,
-                                             JAISConfig, MedusaConfig,
-                                             MLPSpeculatorConfig, MPTConfig,
-                                             NemotronConfig, RWConfig,
-                                             UltravoxConfig,
-                                             Grok1Config)
+                                             GraniteConfig, Grok1Config,
+                                             InternVLChatConfig, JAISConfig,
+                                             MedusaConfig, MLPSpeculatorConfig,
+                                             MPTConfig, NemotronConfig,
+                                             RWConfig, UltravoxConfig)
 # yapf: enable
 from vllm.transformers_utils.utils import check_gguf_file
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 32b16773299ec..70025ad5c6bb9 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -7,6 +7,7 @@
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
 from vllm.transformers_utils.configs.granite import GraniteConfig
+from vllm.transformers_utils.configs.grok1 import Grok1Config
 from vllm.transformers_utils.configs.internvl import InternVLChatConfig
 from vllm.transformers_utils.configs.jais import JAISConfig
 from vllm.transformers_utils.configs.medusa import MedusaConfig
@@ -14,7 +15,6 @@
 from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.nemotron import NemotronConfig
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
-from vllm.transformers_utils.configs.grok1 import Grok1Config
 
 __all__ = [
     "ChatGLMConfig",
diff --git a/vllm/transformers_utils/configs/grok1.py b/vllm/transformers_utils/configs/grok1.py
index 6d6a475788ae8..3eed2744318b1 100644
--- a/vllm/transformers_utils/configs/grok1.py
+++ b/vllm/transformers_utils/configs/grok1.py
@@ -5,31 +5,29 @@ class Grok1Config(PretrainedConfig):
     model_type = "grok-1"
     keys_to_ignore_at_inference = ["past_key_values"]
 
-    def __init__(
-        self,
-        vocab_size=32000,
-        hidden_size=4096,
-        intermediate_size=32768,
-        num_hidden_layers=32,
-        num_attention_heads=32,
-        num_key_value_heads=32,
-        attn_output_multiplier=1.0,
-        max_attn_value=1.0,
-        max_position_embeddings=4096,
-        embedding_multiplier_scale: float = 1.0,
-        output_multiplier_scale: float = 1.0,
-        rms_norm_eps=1e-5,
-        use_cache=True,
-        pad_token_id=None,
-        bos_token_id=1,
-        eos_token_id=2,
-        tie_word_embeddings=True,
-        num_experts_per_tok=2,
-        num_experts=8,
-        output_router_logits=False,
-        router_aux_loss_coef=0.001,
-        **kwargs
-    ):
+    def __init__(self,
+                 vocab_size=32000,
+                 hidden_size=4096,
+                 intermediate_size=32768,
+                 num_hidden_layers=32,
+                 num_attention_heads=32,
+                 num_key_value_heads=32,
+                 attn_output_multiplier=1.0,
+                 max_attn_value=1.0,
+                 max_position_embeddings=4096,
+                 embedding_multiplier_scale: float = 1.0,
+                 output_multiplier_scale: float = 1.0,
+                 rms_norm_eps=1e-5,
+                 use_cache=True,
+                 pad_token_id=None,
+                 bos_token_id=1,
+                 eos_token_id=2,
+                 tie_word_embeddings=True,
+                 num_experts_per_tok=2,
+                 num_experts=8,
+                 output_router_logits=False,
+                 router_aux_loss_coef=0.001,
+                 **kwargs):
         self.vocab_size = vocab_size
         self.attn_output_multiplier = attn_output_multiplier
         self.max_attn_value = max_attn_value
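
Reviewer note, not part of the patch: the vectorized rms_norm_kernel above requires hidden_size % 8 == 0, since it walks each row in vec8_t chunks (vec_hidden_size = hidden_size >> 3). Below is a minimal PyTorch sketch of the math that kernel computes, handy for spot-checking outputs; the helper name and shapes are illustrative.

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor,
                 eps: float) -> torch.Tensor:
    # Mean of squares over the hidden dim; the kernel accumulates this in
    # fp32 via cub::BlockReduce over per-thread vec8 partial sums.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    # out = x * rsqrtf(variance + eps) * weight, as in the kernel.
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

hidden_size = 4096
assert hidden_size % 8 == 0, "vectorized kernel path needs a multiple of 8"
x = torch.randn(2, hidden_size, dtype=torch.float16)
w = torch.ones(hidden_size, dtype=torch.float16)
out = rms_norm_ref(x, w, eps=1e-5)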