Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Grok-1 #181

Merged
merged 3 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 87 additions & 2 deletions csrc/layernorm_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,99 @@ using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif

#ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh"
#else
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif

namespace vllm {

template <typename scalar_t>
struct __align__(16) vec8_t {
scalar_t x, y, z, w, u, v, s, t;

__device__ vec8_t() : x(0), y(0), z(0), w(0), u(0), v(0), s(0), t(0) {}
__device__ vec8_t(scalar_t x, scalar_t y, scalar_t z, scalar_t w, scalar_t u, scalar_t v, scalar_t s, scalar_t t) : x(x), y(y), z(z), w(w), u(u), v(v), s(s), t(t){}

__device__ vec8_t operator*(const vec8_t& other) const {
return vec8_t(x * other.x, y * other.y, z * other.z, w * other.w,
u * other.u, v * other.v, s * other.s, t * other.t);
}

__device__ vec8_t operator*(const float& scale) const {
return vec8_t(x * scale, y * scale, z * scale, w * scale,
u * scale, v * scale, s * scale, t * scale);
}

__device__ vec8_t operator+(const vec8_t& other) const {
return vec8_t(x + other.x, y + other.y, z + other.z, w + other.w,
u + other.u, v + other.v, s + other.s, t + other.t);
}

__device__ void operator+=(const vec8_t& other) {
x += other.x;
y += other.y;
z += other.z;
w += other.w;
u += other.u;
v += other.v;
s += other.s;
t += other.t;
}

__device__ scalar_t sum() const {
return x + y + z + w + u + v + s + t;
}
};

// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
__global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;

vec8_t<scalar_t> v8_variance = {0, 0, 0, 0, 0, 0, 0, 0};

vec8_t<scalar_t>* vectorized_out = reinterpret_cast<vec8_t<scalar_t>*>(out);
vec8_t<scalar_t> const* vectorized_in = reinterpret_cast<vec8_t<scalar_t> const*>(input);
vec8_t<scalar_t> const* vectorized_weight = reinterpret_cast<vec8_t<scalar_t> const*>(weight);
const int vec_hidden_size = hidden_size >> 3;

// Compute variance. Be carefull, hidden_size should multiple of 4.
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
vec8_t<scalar_t> x = vectorized_in[blockIdx.x * vec_hidden_size + idx];
v8_variance += x * x;
}
float v8_variance_sum = v8_variance.sum();

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
float variance = BlockReduce(reduceStore).Reduce(v8_variance_sum, cub::Sum{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
vec8_t<scalar_t> v8_in = vectorized_in[blockIdx.x * vec_hidden_size + idx];
vec8_t<scalar_t> v8_w = vectorized_weight[idx];
vectorized_out[blockIdx.x * vec_hidden_size + idx] = v8_in * s_variance * v8_w;
}
}

template <typename scalar_t>
__global__ void scaled_rms_norm_kernel(
hip_fp8* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float* scale,
const float epsilon, const int num_tokens, const int hidden_size,
const int hidden_size_padded) {
__shared__ float s_variance;
float variance = 0.0f;

Expand All @@ -46,8 +130,9 @@ __global__ void rms_norm_kernel(

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
x = (x * s_variance) * (float)weight[idx] / (*scale);

out[blockIdx.x * hidden_size_padded + idx] = hip_fp8(x);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 0,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
"Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
# transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
Expand Down
Loading
Loading