
Commit

fix lint
wunhuang committed Sep 12, 2024
1 parent e9269da commit faf549f
Showing 5 changed files with 137 additions and 130 deletions.
127 changes: 65 additions & 62 deletions csrc/layernorm_kernels.cu
@@ -30,38 +30,38 @@ template <typename scalar_t>
struct __align__(16) vec8_t {
scalar_t x, y, z, w, u, v, s, t;

__device__ vec8_t() : x(0), y(0), z(0), w(0), u(0), v(0), s(0), t(0) {}
__device__ vec8_t(scalar_t x, scalar_t y, scalar_t z, scalar_t w, scalar_t u, scalar_t v, scalar_t s, scalar_t t) : x(x), y(y), z(z), w(w), u(u), v(v), s(s), t(t){}

__device__ vec8_t operator*(const vec8_t& other) const {
return vec8_t(x * other.x, y * other.y, z * other.z, w * other.w,
u * other.u, v * other.v, s * other.s, t * other.t);
}
__device__ vec8_t() : x(0), y(0), z(0), w(0), u(0), v(0), s(0), t(0) {}
__device__ vec8_t(scalar_t x, scalar_t y, scalar_t z, scalar_t w, scalar_t u,
scalar_t v, scalar_t s, scalar_t t)
: x(x), y(y), z(z), w(w), u(u), v(v), s(s), t(t) {}

__device__ vec8_t operator*(const vec8_t& other) const {
return vec8_t(x * other.x, y * other.y, z * other.z, w * other.w,
u * other.u, v * other.v, s * other.s, t * other.t);
}

__device__ vec8_t operator*(const float& scale) const {
return vec8_t(x * scale, y * scale, z * scale, w * scale,
u * scale, v * scale, s * scale, t * scale);
}
__device__ vec8_t operator*(const float& scale) const {
return vec8_t(x * scale, y * scale, z * scale, w * scale, u * scale,
v * scale, s * scale, t * scale);
}

__device__ vec8_t operator+(const vec8_t& other) const {
return vec8_t(x + other.x, y + other.y, z + other.z, w + other.w,
u + other.u, v + other.v, s + other.s, t + other.t);
}
__device__ vec8_t operator+(const vec8_t& other) const {
return vec8_t(x + other.x, y + other.y, z + other.z, w + other.w,
u + other.u, v + other.v, s + other.s, t + other.t);
}

__device__ void operator+=(const vec8_t& other) {
x += other.x;
y += other.y;
z += other.z;
w += other.w;
u += other.u;
v += other.v;
s += other.s;
t += other.t;
}
__device__ void operator+=(const vec8_t& other) {
x += other.x;
y += other.y;
z += other.z;
w += other.w;
u += other.u;
v += other.v;
s += other.s;
t += other.t;
}

__device__ scalar_t sum() const {
return x + y + z + w + u + v + s + t;
}
__device__ scalar_t sum() const { return x + y + z + w + u + v + s + t; }
};

// TODO(woosuk): Further optimize this kernel.
@@ -71,46 +71,49 @@ __global__ void rms_norm_kernel(
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;

vec8_t<scalar_t> v8_variance = {0, 0, 0, 0, 0, 0, 0, 0};

vec8_t<scalar_t>* vectorized_out = reinterpret_cast<vec8_t<scalar_t>*>(out);
vec8_t<scalar_t> const* vectorized_in = reinterpret_cast<vec8_t<scalar_t> const*>(input);
vec8_t<scalar_t> const* vectorized_weight = reinterpret_cast<vec8_t<scalar_t> const*>(weight);
const int vec_hidden_size = hidden_size >> 3;

// Compute variance. Be careful: hidden_size must be a multiple of 8 for the vec8_t path.
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
vec8_t<scalar_t> x = vectorized_in[blockIdx.x * vec_hidden_size + idx];
v8_variance += x * x;
}
float v8_variance_sum = v8_variance.sum();

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
float variance = BlockReduce(reduceStore).Reduce(v8_variance_sum, cub::Sum{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
vec8_t<scalar_t> v8_in = vectorized_in[blockIdx.x * vec_hidden_size + idx];
vec8_t<scalar_t> v8_w = vectorized_weight[idx];
vectorized_out[blockIdx.x * vec_hidden_size + idx] = v8_in * s_variance * v8_w;
}
__shared__ float s_variance;

vec8_t<scalar_t> v8_variance = {0, 0, 0, 0, 0, 0, 0, 0};

vec8_t<scalar_t>* vectorized_out = reinterpret_cast<vec8_t<scalar_t>*>(out);
vec8_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec8_t<scalar_t> const*>(input);
vec8_t<scalar_t> const* vectorized_weight =
reinterpret_cast<vec8_t<scalar_t> const*>(weight);
const int vec_hidden_size = hidden_size >> 3;

// Compute variance. Be careful: hidden_size must be a multiple of 8 for the vec8_t path.
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
vec8_t<scalar_t> x = vectorized_in[blockIdx.x * vec_hidden_size + idx];
v8_variance += x * x;
}
float v8_variance_sum = v8_variance.sum();

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
float variance =
BlockReduce(reduceStore).Reduce(v8_variance_sum, cub::Sum{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
vec8_t<scalar_t> v8_in = vectorized_in[blockIdx.x * vec_hidden_size + idx];
vec8_t<scalar_t> v8_w = vectorized_weight[idx];
vectorized_out[blockIdx.x * vec_hidden_size + idx] =
v8_in * s_variance * v8_w;
}
}

template <typename scalar_t>
__global__ void scaled_rms_norm_kernel(
hip_fp8* __restrict__ out, // [..., hidden_size]
hip_fp8* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float* scale,
const float epsilon, const int num_tokens, const int hidden_size,
const int hidden_size_padded) {
const float* scale, const float epsilon, const int num_tokens,
const int hidden_size, const int hidden_size_padded) {
__shared__ float s_variance;
float variance = 0.0f;

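For reference, the vectorized rms_norm_kernel above computes standard RMSNorm. The sketch below is a plain PyTorch restatement useful for sanity-checking the CUDA path; it is not vLLM's API. Two caveats the reference does not capture: the kernel requires hidden_size to be a multiple of 8 (for the vec8_t casts), and it accumulates the squared sum in scalar_t before the block reduction, so half-precision inputs may differ slightly from this float reference.

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> torch.Tensor:
    # x: [..., hidden_size], weight: [hidden_size]
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    normed = x.float() * torch.rsqrt(variance + epsilon)
    # Scale by the learned weight, then cast back to the input dtype.
    return (normed * weight.float()).to(x.dtype)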
79 changes: 43 additions & 36 deletions vllm/model_executor/models/grok1.py
@@ -26,7 +26,6 @@

import torch
from torch import nn
from vllm.transformers_utils.configs import Grok1Config

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
@@ -47,15 +46,16 @@
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Grok1Config

from .interfaces import SupportsLoRA
from .utils import is_pp_missing_parameter, make_layers

attn_output_multiplier = 0.08838834764831845
output_multiplier_scale = 0.5773502691896257
max_attn_val = 30.0
reduce_conversion_kernel: bool = True if os.getenv("VLLM_FP8_REDUCE_CONV",
'0') == "1" else False
reduce_conversion_kernel: bool = os.getenv("VLLM_FP8_REDUCE_CONV", '0') == "1"


class Grok1MoE(nn.Module):
"""A tensor-parallel MoE implementation for Grok1 that shards each expert
@@ -201,27 +201,29 @@ def __init__(
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.attn = Grok1Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.moe_block = Grok1MoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.moe_block")

self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.attn = Grok1Attention(hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.moe_block = Grok1MoE(num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.moe_block")

self.pre_attn_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attn_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_moe_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_moe_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)

def forward(
self,
@@ -234,10 +236,15 @@ def forward(
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.pre_attn_norm(hidden_states, self.attn.qkv_proj.activation_scaling_factor) if reduce_conversion_kernel else self.pre_attn_norm(hidden_states)
hidden_states = self.pre_attn_norm(
hidden_states, self.attn.qkv_proj.activation_scaling_factor
) if reduce_conversion_kernel else self.pre_attn_norm(
hidden_states)
else:
hidden_states, residual = self.pre_attn_norm(hidden_states, self.attn.qkv_proj.activation_scaling_factor, residual) if reduce_conversion_kernel else self.pre_attn_norm(
hidden_states, residual)
hidden_states, residual = self.pre_attn_norm(
hidden_states, self.attn.qkv_proj.activation_scaling_factor,
residual) if reduce_conversion_kernel else self.pre_attn_norm(
hidden_states, residual)
hidden_states = self.attn(
positions=positions,
hidden_states=hidden_states,
@@ -254,7 +261,7 @@ def forward(
hidden_states = self.moe_block(hidden_states)

hidden_states = self.post_moe_norm(hidden_states)
return hidden_states, residual
return hidden_states, residual


class Grok1Model(nn.Module):
@@ -357,10 +364,10 @@ def __init__(
self.lora_config = lora_config

self.model = Grok1Model(config,
cache_config,
quant_config,
lora_config=lora_config,
prefix="model")
cache_config,
quant_config,
lora_config=lora_config,
prefix="model")
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
@@ -377,7 +384,8 @@ def __init__(
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, output_multiplier_scale)
config.vocab_size,
output_multiplier_scale)
self.sampler = Sampler()

def forward(
@@ -489,9 +497,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if "norm.scale" in name:
name = name.replace("scale", "weight")

if "lm_head" in name:
if self.config.tie_word_embeddings:
continue
if "lm_head" in name and self.config.tie_word_embeddings:
continue

if name is None:
continue
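The VLLM_FP8_REDUCE_CONV gating added to Grok1DecoderLayer.forward above is easier to follow in isolation. A simplified sketch of the same conditional is below; the helper name and the norm/scale arguments are placeholders for illustration, not part of this diff.

import os

# Mirrors the module-level flag in grok1.py: when the env var is "1", the
# pre-attention RMSNorm also quantizes its output using the QKV projection's
# activation scaling factor, skipping a separate FP8 conversion pass.
reduce_conversion_kernel: bool = os.getenv("VLLM_FP8_REDUCE_CONV", "0") == "1"

def pre_attn_normalize(norm, hidden_states, residual, qkv_act_scale):
    # Hypothetical helper restating the branch structure from the diff.
    if residual is None:
        residual = hidden_states
        if reduce_conversion_kernel:
            hidden_states = norm(hidden_states, qkv_act_scale)
        else:
            hidden_states = norm(hidden_states)
    else:
        if reduce_conversion_kernel:
            hidden_states, residual = norm(hidden_states, qkv_act_scale,
                                           residual)
        else:
            hidden_states, residual = norm(hidden_states, residual)
    return hidden_states, residual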
11 changes: 5 additions & 6 deletions vllm/transformers_utils/config.py
@@ -18,12 +18,11 @@
# yapf: disable
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
EAGLEConfig, ExaoneConfig,
GraniteConfig, InternVLChatConfig,
JAISConfig, MedusaConfig,
MLPSpeculatorConfig, MPTConfig,
NemotronConfig, RWConfig,
UltravoxConfig,
Grok1Config)
GraniteConfig, Grok1Config,
InternVLChatConfig, JAISConfig,
MedusaConfig, MLPSpeculatorConfig,
MPTConfig, NemotronConfig,
RWConfig, UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file

2 changes: 1 addition & 1 deletion vllm/transformers_utils/configs/__init__.py
@@ -7,14 +7,14 @@
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.granite import GraniteConfig
from vllm.transformers_utils.configs.grok1 import Grok1Config
from vllm.transformers_utils.configs.internvl import InternVLChatConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.transformers_utils.configs.grok1 import Grok1Config

__all__ = [
"ChatGLMConfig",
48 changes: 23 additions & 25 deletions vllm/transformers_utils/configs/grok1.py
@@ -5,31 +5,29 @@ class Grok1Config(PretrainedConfig):
model_type = "grok-1"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=32768,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
attn_output_multiplier=1.0,
max_attn_value=1.0,
max_position_embeddings=4096,
embedding_multiplier_scale: float = 1.0,
output_multiplier_scale: float = 1.0,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=True,
num_experts_per_tok=2,
num_experts=8,
output_router_logits=False,
router_aux_loss_coef=0.001,
**kwargs
):
def __init__(self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=32768,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
attn_output_multiplier=1.0,
max_attn_value=1.0,
max_position_embeddings=4096,
embedding_multiplier_scale: float = 1.0,
output_multiplier_scale: float = 1.0,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=True,
num_experts_per_tok=2,
num_experts=8,
output_router_logits=False,
router_aux_loss_coef=0.001,
**kwargs):
self.vocab_size = vocab_size
self.attn_output_multiplier = attn_output_multiplier
self.max_attn_value = max_attn_value
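With Grok1Config now exported from vllm.transformers_utils.configs, it can be constructed directly. A small illustrative example follows; it assumes a vLLM build that includes this change, and the override values are arbitrary rather than Grok-1's real dimensions.

from vllm.transformers_utils.configs import Grok1Config

# Unspecified fields fall back to the defaults in the __init__ shown above.
config = Grok1Config(
    hidden_size=1024,
    intermediate_size=4096,
    num_hidden_layers=4,
    num_attention_heads=16,
    num_key_value_heads=4,
    num_experts=8,
    num_experts_per_tok=2,
)
print(config.model_type)   # "grok-1"
print(config.vocab_size)   # 32000 (default)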
