diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 7a7a25d2173d2..60533db230e7d 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -18,8 +18,52 @@ using __nv_bfloat16 = __hip_bfloat16; using __nv_bfloat162 = __hip_bfloat162; #endif +#ifdef USE_ROCM + #include "quantization/fp8/amd/quant_utils.cuh" +#else + #include "quantization/fp8/nvidia/quant_utils.cuh" +#endif + namespace vllm { +template +struct __align__(16) vec8_t { + scalar_t x, y, z, w, u, v, s, t; + + __device__ vec8_t() : x(0), y(0), z(0), w(0), u(0), v(0), s(0), t(0) {} + __device__ vec8_t(scalar_t x, scalar_t y, scalar_t z, scalar_t w, scalar_t u, + scalar_t v, scalar_t s, scalar_t t) + : x(x), y(y), z(z), w(w), u(u), v(v), s(s), t(t) {} + + __device__ vec8_t operator*(const vec8_t& other) const { + return vec8_t(x * other.x, y * other.y, z * other.z, w * other.w, + u * other.u, v * other.v, s * other.s, t * other.t); + } + + __device__ vec8_t operator*(const float& scale) const { + return vec8_t(x * scale, y * scale, z * scale, w * scale, u * scale, + v * scale, s * scale, t * scale); + } + + __device__ vec8_t operator+(const vec8_t& other) const { + return vec8_t(x + other.x, y + other.y, z + other.z, w + other.w, + u + other.u, v + other.v, s + other.s, t + other.t); + } + + __device__ void operator+=(const vec8_t& other) { + x += other.x; + y += other.y; + z += other.z; + w += other.w; + u += other.u; + v += other.v; + s += other.s; + t += other.t; + } + + __device__ scalar_t sum() const { return x + y + z + w + u + v + s + t; } +}; + // TODO(woosuk): Further optimize this kernel. template __global__ void rms_norm_kernel( @@ -28,6 +72,49 @@ __global__ void rms_norm_kernel( const scalar_t* __restrict__ weight, // [hidden_size] const float epsilon, const int num_tokens, const int hidden_size) { __shared__ float s_variance; + + vec8_t v8_variance = {0, 0, 0, 0, 0, 0, 0, 0}; + + vec8_t* vectorized_out = reinterpret_cast*>(out); + vec8_t const* vectorized_in = + reinterpret_cast const*>(input); + vec8_t const* vectorized_weight = + reinterpret_cast const*>(weight); + const int vec_hidden_size = hidden_size >> 3; + + // Compute variance. Be careful, hidden_size should multiple of 4. + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + vec8_t x = vectorized_in[blockIdx.x * vec_hidden_size + idx]; + v8_variance += x * x; + } + float v8_variance_sum = v8_variance.sum(); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + float variance = + BlockReduce(reduceStore).Reduce(v8_variance_sum, cub::Sum{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + vec8_t v8_in = vectorized_in[blockIdx.x * vec_hidden_size + idx]; + vec8_t v8_w = vectorized_weight[idx]; + vectorized_out[blockIdx.x * vec_hidden_size + idx] = + v8_in * s_variance * v8_w; + } +} + +template +__global__ void scaled_rms_norm_kernel( + hip_fp8* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float* scale, const float epsilon, const int num_tokens, + const int hidden_size, const int hidden_size_padded) { + __shared__ float s_variance; float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { @@ -46,8 +133,9 @@ __global__ void rms_norm_kernel( for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { float x = (float)input[blockIdx.x * hidden_size + idx]; - out[blockIdx.x * hidden_size + idx] = - ((scalar_t)(x * s_variance)) * weight[idx]; + x = (x * s_variance) * (float)weight[idx] / (*scale); + + out[blockIdx.x * hidden_size_padded + idx] = hip_fp8(x); } } diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..83369664606d2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,13 @@ +{ + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 4db847029566f..8345caebedca2 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -39,6 +39,7 @@ "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), + "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py new file mode 100644 index 0000000000000..a2bc0a8c792a0 --- /dev/null +++ b/vllm/model_executor/models/grok1.py @@ -0,0 +1,509 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Grok1 model.""" +import os +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Grok1Config + +from .interfaces import SupportsLoRA +from .utils import is_pp_missing_parameter, make_layers + +attn_output_multiplier = 0.08838834764831845 +output_multiplier_scale = 0.5773502691896257 +max_attn_val = 30.0 +reduce_conversion_kernel: bool = os.getenv("VLLM_FP8_REDUCE_CONV", '0') == "1" + + +class Grok1MoE(nn.Module): + """A tensor-parallel MoE implementation for Grok1 that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + + self.gate = ReplicatedLinear(hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate") + + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class Grok1Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Grok1DecoderLayer(nn.Module): + + def __init__( + self, + config: Grok1Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.attn = Grok1Attention(hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + self.moe_block = Grok1MoE(num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.moe_block") + + self.pre_attn_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attn_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_moe_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_moe_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.pre_attn_norm( + hidden_states, self.attn.qkv_proj.activation_scaling_factor + ) if reduce_conversion_kernel else self.pre_attn_norm( + hidden_states) + else: + hidden_states, residual = self.pre_attn_norm( + hidden_states, self.attn.qkv_proj.activation_scaling_factor, + residual) if reduce_conversion_kernel else self.pre_attn_norm( + hidden_states, residual) + hidden_states = self.attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states = self.post_attn_norm(hidden_states) + + ### fused_moe performance bad + hidden_states, residual = self.pre_moe_norm(hidden_states, residual) + + hidden_states = self.moe_block(hidden_states) + + hidden_states = self.post_moe_norm(hidden_states) + return hidden_states, residual + + +class Grok1Model(nn.Module): + + def __init__( + self, + config: Grok1Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embedding_multiplier_scale = config.embedding_multiplier_scale + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Grok1DecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + hidden_states = hidden_states * self.embedding_multiplier_scale + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Grok1ForCausalLM(nn.Module, SupportsLoRA): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: Grok1Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = Grok1Model(config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model") + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + output_multiplier_scale) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="linear", + ckpt_down_proj_name="linear_1", + ckpt_up_proj_name="linear_v", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + + if "norm.scale" in name: + name = name.replace("scale", "weight") + + if "lm_head" in name and self.config.tie_word_embeddings: + continue + + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 13fcf6b918603..199c1caa2164e 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -18,11 +18,11 @@ # yapf: disable from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, EAGLEConfig, ExaoneConfig, - GraniteConfig, InternVLChatConfig, - JAISConfig, MedusaConfig, - MLPSpeculatorConfig, MPTConfig, - NemotronConfig, RWConfig, - UltravoxConfig) + GraniteConfig, Grok1Config, + InternVLChatConfig, JAISConfig, + MedusaConfig, MLPSpeculatorConfig, + MPTConfig, NemotronConfig, + RWConfig, UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file @@ -52,6 +52,7 @@ # Granite can be removed from here once we have upgraded to # transformers 4.45+ "granite": GraniteConfig, + "grok-1": Grok1Config, } for name, cls in _CONFIG_REGISTRY.items(): diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 8381c5227584e..70025ad5c6bb9 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -7,6 +7,7 @@ # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.granite import GraniteConfig +from vllm.transformers_utils.configs.grok1 import Grok1Config from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.medusa import MedusaConfig @@ -31,4 +32,5 @@ # Granite can be removed from here once we have upgraded to # transformers 4.45+ "GraniteConfig", + "Grok1Config", ] diff --git a/vllm/transformers_utils/configs/grok1.py b/vllm/transformers_utils/configs/grok1.py new file mode 100644 index 0000000000000..3eed2744318b1 --- /dev/null +++ b/vllm/transformers_utils/configs/grok1.py @@ -0,0 +1,60 @@ +from transformers.configuration_utils import PretrainedConfig + + +class Grok1Config(PretrainedConfig): + model_type = "grok-1" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__(self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=32768, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + attn_output_multiplier=1.0, + max_attn_value=1.0, + max_position_embeddings=4096, + embedding_multiplier_scale: float = 1.0, + output_multiplier_scale: float = 1.0, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=True, + num_experts_per_tok=2, + num_experts=8, + output_router_logits=False, + router_aux_loss_coef=0.001, + **kwargs): + self.vocab_size = vocab_size + self.attn_output_multiplier = attn_output_multiplier + self.max_attn_value = max_attn_value + self.max_position_embeddings = max_position_embeddings + self.embedding_multiplier_scale = embedding_multiplier_scale + self.output_multiplier_scale = output_multiplier_scale + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + )