From 963f312aa3a4726e53c2dd8d563c94cd62a8069e Mon Sep 17 00:00:00 2001 From: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com> Date: Tue, 10 Sep 2024 16:47:13 -0500 Subject: [PATCH 1/5] Restoring deleted .buildkite/test-template.j2 (#177) Co-authored-by: Alexei Ivanov --- .buildkite/test-template.j2 | 38 +++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 .buildkite/test-template.j2 diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 new file mode 100644 index 0000000000000..e7b24268ba398 --- /dev/null +++ b/.buildkite/test-template.j2 @@ -0,0 +1,38 @@ +{% set docker_image = "public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT" %} +{% set docker_image_amd = "rocm/vllm-ci:$BUILDKITE_COMMIT" %} +{% set default_working_dir = "vllm/tests" %} +{% set hf_home = "/root/.cache/huggingface" %} + +steps: + - label: ":docker: build image" + depends_on: ~ + commands: + - "docker build --build-arg max_jobs=16 --tag {{ docker_image_amd }} -f Dockerfile.rocm --progress plain ." + - "docker push {{ docker_image_amd }}" + key: "amd-build" + env: + DOCKER_BUILDKIT: "1" + retry: + automatic: + - exit_status: -1 # Agent was lost + limit: 5 + - exit_status: -10 # Agent was lost + limit: 5 + agents: + queue: amd + +{% for step in steps %} +{% if step.mirror_hardwares and "amd" in step.mirror_hardwares %} + - label: "AMD: {{ step.label }}" + depends_on: + - "amd-build" + agents: + queue: amd + commands: + - bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" && ")) | safe }}" + env: + DOCKER_BUILDKIT: "1" + priority: 100 + soft_fail: true +{% endif %} +{% endfor %} From 5cf1c75ebc5bbd008bf5cf0bd6ce7dca4b1398a9 Mon Sep 17 00:00:00 2001 From: Joe Shajrawi <17753158+shajrawi@users.noreply.github.com> Date: Wed, 11 Sep 2024 15:50:21 -0500 Subject: [PATCH 2/5] Support commandr on ROCm (#180) Disable torch compile --- vllm/model_executor/models/commandr.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 649dc798d22dc..ba12fd2044953 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -50,7 +50,6 @@ from .interfaces import SupportsLoRA -@torch.compile def layer_norm_func(hidden_states, weight, variance_epsilon): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) From dc948ab3716f98b25633314d9e02f8f145e730af Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Wed, 11 Sep 2024 17:36:43 -0400 Subject: [PATCH 3/5] Correct type hint (#173) --- vllm/engine/llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 94271c4a93151..49ef3fcbc036a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1443,7 +1443,7 @@ def _process_model_outputs(self, return None def _advance_to_next_step( - self, output: List[SamplerOutput], + self, output: SamplerOutput, seq_group_metadata_list: List[SequenceGroupMetadata], scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: """Given model output from a single run, append the tokens to the From 78e6e0f5a875bf202edae5df02305a2a3f9e37f9 Mon Sep 17 00:00:00 2001 From: sanyalington Date: Thu, 12 Sep 2024 06:06:28 +0530 Subject: [PATCH 4/5] update custom PA kernel with support for fp8 kv cache dtype (#87) * 
update custom PA kernel with support for fp8 kv cache dtype; change custom PA partition size to 512 to prefer throughput scenarios at cost of latency * Fix lint * Fix BF16 with FP8 KV cache (scaled conversion incorrectly done in fp16) --------- Co-authored-by: Matthew Wong --- csrc/custom/custom_ops.h | 3 +- .../custom/paged_attention/attention_ll4mi.cu | 253 +++++++++++------- csrc/custom/torch_bindings.cpp | 3 +- csrc/quantization/fp8/amd/quant_utils.cuh | 2 - vllm/_custom_ops.py | 12 +- vllm/attention/ops/paged_attn.py | 32 +-- 6 files changed, 179 insertions(+), 126 deletions(-) diff --git a/csrc/custom/custom_ops.h b/csrc/custom/custom_ops.h index f6ea892b2ffa5..f8ab5ee5544df 100644 --- a/csrc/custom/custom_ops.h +++ b/csrc/custom/custom_ops.h @@ -18,4 +18,5 @@ void paged_attention_custom(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& context_lens, int64_t block_size, int64_t max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype); \ No newline at end of file + const std::string& kv_cache_dtype, double k_scale, + double v_scale); diff --git a/csrc/custom/paged_attention/attention_ll4mi.cu b/csrc/custom/paged_attention/attention_ll4mi.cu index e78dce4c30de3..b38ec30dfcdc1 100644 --- a/csrc/custom/paged_attention/attention_ll4mi.cu +++ b/csrc/custom/paged_attention/attention_ll4mi.cu @@ -6,6 +6,8 @@ #include "cuda_compat.h" #include +#include "../../attention/dtype_fp8.cuh" +#include "../../quantization/fp8/amd/quant_utils.cuh" #if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ defined(__gfx941__) || defined(__gfx942__)) @@ -45,6 +47,8 @@ typedef struct _B16x8 { _B16x4 xy[2]; } _B16x8; +using _B8x8 = uint2; + ////// Non temporal load stores /////// #if 1 @@ -232,18 +236,40 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, } } +template +__device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input, + const float scale) { + union alignas(16) { + uint4 u4; + _B16x8 u16x8; + vllm::bf16_8_t b16x8; + } tmp; + if constexpr (std::is_same::value) { + tmp.u4 = vllm::fp8::scaled_convert(input, scale); + return tmp.u16x8; + } else if constexpr (std::is_same::value) { + tmp.b16x8 = vllm::fp8::scaled_convert( + input, scale); + return tmp.u16x8; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + /////////////////////////////////////// // grid (num_seqs, num_partitions,num_heads/gqa_ratio) // block (partition size) -template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ context_lens, // [num_seqs] @@ -256,10 +282,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - #if 0 - scalar_t* __restrict__ 
qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] - #endif - int max_ctx_blocks) { + int max_ctx_blocks, float k_scale, float v_scale) { constexpr int NWARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; @@ -286,12 +309,14 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( constexpr int x = 16 / sizeof(scalar_t); constexpr int KHELOOP = HEAD_SIZE / x; _B16x8 Klocal[KHELOOP]; + _B8x8 Klocalb8[KHELOOP]; constexpr int VHELOOP = HEAD_SIZE / WARP_SIZE; // v head_size dimension is distributed across lanes constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2 // 8xtokens _B16x8 Vlocal[VHELOOP][VTLOOP]; + _B8x8 Vlocalb8[VHELOOP][VTLOOP]; floatx4 dout[QHLOOP]; float qk_max[QHLOOP]; #pragma unroll @@ -343,6 +368,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block; vphysical_blocks[b] = block_table[vblock_idx_ctx]; } + // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems const scalar_t* q_ptr = q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; @@ -362,17 +388,29 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( Qlocal[QHLOOP - 1].xy[1] = {0}; } - const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride + - wg_start_kv_head_idx * kv_head_stride; + const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + wg_start_kv_head_idx * kv_head_stride; const int physical_block_offset = local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset // is already cast as _H8 - - const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); #pragma unroll - for (int d = 0; d < KHELOOP; d++) { - Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; + for (int d = 0; d < KHELOOP; d++) { + Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; + } + } else { + constexpr int X = 16 / sizeof(cache_t); + const cache_t* k_ptr2 = k_ptr + physical_block_offset * X; + #pragma unroll + for (int d = 0; d < KHELOOP; d++) { + const int head_elem = d * 8; + const int offset1 = head_elem / X; + const int offset2 = head_elem % X; + const cache_t* k_ptr3 = k_ptr2 + offset1 * BLOCK_SIZE * X + offset2; + Klocalb8[d] = *reinterpret_cast(k_ptr3); + } } float alibi_slope[QHLOOP]; @@ -386,28 +424,64 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } - const scalar_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; - const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); - // iterate over each v block + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block #pragma unroll - for (int b = 0; b < VBLOCKS; b++) { - // int32 physical_block_number leads to overflow when multiplied with - // kv_block_stride - const int64_t vphysical_block_number = - static_cast(vphysical_blocks[b]); - const _B16x8* v_ptrh8b = - v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; - // iterate over each head elem (within head_size) + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); 
+ const _B16x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) #pragma unroll - for (int h = 0; h < VHELOOP; h++) { - const int head_size_elem = h * WARP_SIZE + laneid; - const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; - // iterate over all velems within block + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block #pragma unroll - for (int d = 0; d < BLOCK_SIZE / 8; d++) { - Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + } } } + } else { + const _B8x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _B8x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) + #pragma unroll + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block + #pragma unroll + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + const _B8x8 Vlocalb8 = v_ptrh8be[d]; + Vlocal[h][b * BLOCK_SIZE / 8 + d] = + scaled_convert_b8x8(Vlocalb8, v_scale); + } + } + } + } + + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + #pragma unroll + for (int d = 0; d < KHELOOP; d++) { + Klocal[d] = + scaled_convert_b8x8(Klocalb8[d], k_scale); + } } #pragma unroll @@ -599,6 +673,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } } else { // warp in context + // iterate across heads #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { @@ -684,22 +759,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } } - - #if 0 - const int num_seqs = gridDim.x; - const int global_token4id = global_token_idx/4; - #pragma unroll - for (int t=0;t<4;t++) { - #pragma unroll - for (int h=0;h __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] const int num_kv_heads, const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ context_lens, // [num_seqs] @@ -894,10 +955,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - #if 0 - scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] - #endif 
- int max_ctx_blocks) { + int max_ctx_blocks, float k_scale, float v_scale) { UNREACHABLE_CODE } @@ -919,26 +977,24 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support #define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ - paged_attention_ll4mi_QKV_kernel \ + paged_attention_ll4mi_QKV_kernel \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ - exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks); + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ + k_scale, v_scale); -template +template void paged_attention_custom_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, const int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& context_lens, - int max_context_len, -#if 0 - torch::Tensor& qk_out, - torch::Tensor& softmax_out, -#endif - const c10::optional& alibi_slopes) { - + int max_context_len, const c10::optional& alibi_slopes, + float k_scale, float v_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -958,14 +1014,10 @@ void paged_attention_custom_launcher( float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* context_lens_ptr = context_lens.data_ptr(); -#if 0 - T* qk_out_ptr = reinterpret_cast(qk_out.data_ptr()); - T* softmax_out_ptr = reinterpret_cast(softmax_out.data_ptr()); -#endif const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); const int max_num_partitions = @@ -1052,32 +1104,32 @@ void paged_attention_custom_launcher( } } -#define CALL_CUSTOM_LAUNCHER(T, BLK_SIZE, HEAD_SIZE) \ - paged_attention_custom_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, max_context_len, \ - alibi_slopes); +#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, max_context_len, \ + alibi_slopes, k_scale, v_scale); -#define CALL_CUSTOM_LAUNCHER_BLK(T, HEAD_SIZE) \ +#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ switch (block_size) { \ case 16: \ - CALL_CUSTOM_LAUNCHER(T, 16, HEAD_SIZE); \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ break; \ case 32: \ - CALL_CUSTOM_LAUNCHER(T, 32, HEAD_SIZE); \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ } -#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T) \ +#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \ switch (head_size) { \ case 64: \ - CALL_CUSTOM_LAUNCHER_BLK(T, 64); \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \ break; \ case 128: \ - CALL_CUSTOM_LAUNCHER_BLK(T, 
128); \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \ break; \ default: \ TORCH_CHECK(false, "Unsupported head size: ", head_size); \ @@ -1100,15 +1152,30 @@ void paged_attention_custom( torch::Tensor& context_lens, // [num_seqs] int64_t block_size, int64_t max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype) { - assert(kv_cache_dtype == "auto"); + const std::string& kv_cache_dtype, double k_scale, double v_scale) { const int head_size = query.size(2); - if (query.dtype() == at::ScalarType::Half) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16); + if (kv_cache_dtype == "auto") { + if (query.dtype() == at::ScalarType::Half) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16, + vllm::Fp8KVCacheDataType::kAuto); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16, + vllm::Fp8KVCacheDataType::kAuto); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } + } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { + if (query.dtype() == at::ScalarType::Half) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + TORCH_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype); } } diff --git a/csrc/custom/torch_bindings.cpp b/csrc/custom/torch_bindings.cpp index 73f804f6f41df..dc26ac5e57204 100644 --- a/csrc/custom/torch_bindings.cpp +++ b/csrc/custom/torch_bindings.cpp @@ -19,7 +19,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, custom_ops) { " Tensor context_lens, int block_size," " int max_context_len," " Tensor? 
alibi_slopes," - " str kv_cache_dtype) -> ()"); + " str kv_cache_dtype," + " float k_scale, float v_scale) -> ()"); custom_ops.impl("paged_attention_custom", torch::kCUDA, &paged_attention_custom); custom_ops.def( diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh index b9887d685d47c..4b77817f2df82 100644 --- a/csrc/quantization/fp8/amd/quant_utils.cuh +++ b/csrc/quantization/fp8/amd/quant_utils.cuh @@ -378,8 +378,6 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) { return __float2bfloat16(f * scale); } -using __nv_bfloat162 = __hip_bfloat162; - // fp8x2 -> __nv_bfloat162 template <> __inline__ __device__ __nv_bfloat162 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 09a2cdc4174be..7488da9ad5e91 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -684,13 +684,13 @@ def paged_attention_custom( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, + k_scale: int, + v_scale: int, ) -> None: - torch.ops._custom_C.paged_attention_custom(out, exp_sum, max_logits, - tmp_out, query, key_cache, - value_cache, num_kv_heads, - scale, block_tables, seq_lens, - block_size, max_seq_len, - alibi_slopes, kv_cache_dtype) + torch.ops._custom_C.paged_attention_custom( + out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, + alibi_slopes, kv_cache_dtype, k_scale, v_scale) def wvSpltK(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, N: int, diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 9777c88e0ef15..b055319c8a5ac 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -16,7 +16,7 @@ # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. 
_PARTITION_SIZE_V1V2 = 512 -_PARTITION_SIZE_CUSTOM = 256 +_PARTITION_SIZE_CUSTOM = 512 @dataclass @@ -122,11 +122,8 @@ def forward_decode( num_seqs, num_heads, head_size = query.shape gqa_ratio = num_heads // num_kv_heads use_custom = (custom_attn_available - and (query.dtype == torch.half - or query.dtype == torch.bfloat16) - and (head_size == 128 or head_size == 64) - and (block_size == 16 or block_size == 32) - and kv_cache_dtype == "auto" + and query.dtype in (torch.half, torch.bfloat16) + and head_size in (64, 128) and block_size in (16, 32) and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) if not use_custom: @@ -208,23 +205,12 @@ def forward_decode( blocksparse_head_sliding_step, ) else: - ops.paged_attention_custom( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - ) + ops.paged_attention_custom(output, exp_sums, max_logits, + tmp_output, query, key_cache, + value_cache, num_kv_heads, scale, + block_tables, seq_lens, block_size, + max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale) return output @staticmethod From b1c32735c02baecb8cfbd6e72486891e6b5fe0b1 Mon Sep 17 00:00:00 2001 From: kk <43161300+kkHuang-amd@users.noreply.github.com> Date: Thu, 12 Sep 2024 22:48:35 +0800 Subject: [PATCH 5/5] Support Grok-1 (#181) * Support Grok-1 * fix lint * fix ruff reported issue --------- Co-authored-by: wunhuang --- csrc/layernorm_kernels.cu | 92 +++- ...MD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json | 13 + vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/grok1.py | 509 ++++++++++++++++++ vllm/transformers_utils/config.py | 11 +- vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/grok1.py | 60 +++ 7 files changed, 681 insertions(+), 7 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/models/grok1.py create mode 100644 vllm/transformers_utils/configs/grok1.py diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 7a7a25d2173d2..60533db230e7d 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -18,8 +18,52 @@ using __nv_bfloat16 = __hip_bfloat16; using __nv_bfloat162 = __hip_bfloat162; #endif +#ifdef USE_ROCM + #include "quantization/fp8/amd/quant_utils.cuh" +#else + #include "quantization/fp8/nvidia/quant_utils.cuh" +#endif + namespace vllm { +template +struct __align__(16) vec8_t { + scalar_t x, y, z, w, u, v, s, t; + + __device__ vec8_t() : x(0), y(0), z(0), w(0), u(0), v(0), s(0), t(0) {} + __device__ vec8_t(scalar_t x, scalar_t y, scalar_t z, scalar_t w, scalar_t u, + scalar_t v, scalar_t s, scalar_t t) + : x(x), y(y), z(z), w(w), u(u), v(v), s(s), t(t) {} + + __device__ vec8_t operator*(const vec8_t& other) const { + return vec8_t(x * other.x, y * other.y, z * other.z, w * other.w, + u * other.u, v * other.v, s * other.s, t * other.t); + } + + __device__ vec8_t operator*(const float& scale) const { + return vec8_t(x * scale, y * scale, z * scale, w * scale, u * scale, + v * scale, s * scale, t * scale); + } + + __device__ vec8_t operator+(const vec8_t& other) const { + return vec8_t(x + other.x, y + other.y, z + other.z, w + other.w, + u + other.u, v + other.v, s + other.s, t + other.t); + } + + __device__ void operator+=(const vec8_t& other) { + x += other.x; + y += other.y; + z += other.z; 
+ w += other.w; + u += other.u; + v += other.v; + s += other.s; + t += other.t; + } + + __device__ scalar_t sum() const { return x + y + z + w + u + v + s + t; } +}; + // TODO(woosuk): Further optimize this kernel. template __global__ void rms_norm_kernel( @@ -28,6 +72,49 @@ __global__ void rms_norm_kernel( const scalar_t* __restrict__ weight, // [hidden_size] const float epsilon, const int num_tokens, const int hidden_size) { __shared__ float s_variance; + + vec8_t v8_variance = {0, 0, 0, 0, 0, 0, 0, 0}; + + vec8_t* vectorized_out = reinterpret_cast*>(out); + vec8_t const* vectorized_in = + reinterpret_cast const*>(input); + vec8_t const* vectorized_weight = + reinterpret_cast const*>(weight); + const int vec_hidden_size = hidden_size >> 3; + + // Compute variance. Be careful, hidden_size should multiple of 4. + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + vec8_t x = vectorized_in[blockIdx.x * vec_hidden_size + idx]; + v8_variance += x * x; + } + float v8_variance_sum = v8_variance.sum(); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + float variance = + BlockReduce(reduceStore).Reduce(v8_variance_sum, cub::Sum{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + vec8_t v8_in = vectorized_in[blockIdx.x * vec_hidden_size + idx]; + vec8_t v8_w = vectorized_weight[idx]; + vectorized_out[blockIdx.x * vec_hidden_size + idx] = + v8_in * s_variance * v8_w; + } +} + +template +__global__ void scaled_rms_norm_kernel( + hip_fp8* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float* scale, const float epsilon, const int num_tokens, + const int hidden_size, const int hidden_size_padded) { + __shared__ float s_variance; float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { @@ -46,8 +133,9 @@ __global__ void rms_norm_kernel( for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { float x = (float)input[blockIdx.x * hidden_size + idx]; - out[blockIdx.x * hidden_size + idx] = - ((scalar_t)(x * s_variance)) * weight[idx]; + x = (x * s_variance) * (float)weight[idx] / (*scale); + + out[blockIdx.x * hidden_size_padded + idx] = hip_fp8(x); } } diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..83369664606d2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,13 @@ +{ + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 4db847029566f..8345caebedca2 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -39,6 +39,7 @@ "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), + 
"Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py new file mode 100644 index 0000000000000..a2bc0a8c792a0 --- /dev/null +++ b/vllm/model_executor/models/grok1.py @@ -0,0 +1,509 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Grok1 model.""" +import os +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Grok1Config + +from .interfaces import SupportsLoRA +from .utils import is_pp_missing_parameter, make_layers + +attn_output_multiplier = 0.08838834764831845 +output_multiplier_scale = 0.5773502691896257 +max_attn_val = 30.0 +reduce_conversion_kernel: bool = os.getenv("VLLM_FP8_REDUCE_CONV", '0') == "1" + + +class Grok1MoE(nn.Module): + """A tensor-parallel MoE implementation for Grok1 that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. 
+ """ + + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + + self.gate = ReplicatedLinear(hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate") + + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class Grok1Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Grok1DecoderLayer(nn.Module): + + def __init__( + self, + config: Grok1Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.attn = Grok1Attention(hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + self.moe_block = Grok1MoE(num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.moe_block") + + self.pre_attn_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attn_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_moe_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_moe_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.pre_attn_norm( + hidden_states, self.attn.qkv_proj.activation_scaling_factor + ) if reduce_conversion_kernel else self.pre_attn_norm( + hidden_states) + else: + hidden_states, residual = self.pre_attn_norm( + hidden_states, self.attn.qkv_proj.activation_scaling_factor, + residual) if reduce_conversion_kernel else self.pre_attn_norm( + hidden_states, residual) + hidden_states = self.attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states = self.post_attn_norm(hidden_states) + + 
### fused_moe performance bad + hidden_states, residual = self.pre_moe_norm(hidden_states, residual) + + hidden_states = self.moe_block(hidden_states) + + hidden_states = self.post_moe_norm(hidden_states) + return hidden_states, residual + + +class Grok1Model(nn.Module): + + def __init__( + self, + config: Grok1Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embedding_multiplier_scale = config.embedding_multiplier_scale + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Grok1DecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + hidden_states = hidden_states * self.embedding_multiplier_scale + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Grok1ForCausalLM(nn.Module, SupportsLoRA): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: Grok1Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = Grok1Model(config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model") + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if self.config.tie_word_embeddings: + 
self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + output_multiplier_scale) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="linear", + ckpt_down_proj_name="linear_1", + ckpt_up_proj_name="linear_v", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + + if "norm.scale" in name: + name = name.replace("scale", "weight") + + if "lm_head" in name and self.config.tie_word_embeddings: + continue + + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 13fcf6b918603..199c1caa2164e 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -18,11 +18,11 @@ # yapf: disable from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, EAGLEConfig, ExaoneConfig, - GraniteConfig, InternVLChatConfig, - JAISConfig, MedusaConfig, - MLPSpeculatorConfig, MPTConfig, - NemotronConfig, RWConfig, - UltravoxConfig) + GraniteConfig, Grok1Config, + InternVLChatConfig, JAISConfig, + MedusaConfig, MLPSpeculatorConfig, + MPTConfig, NemotronConfig, + RWConfig, UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file @@ -52,6 +52,7 @@ # Granite can be removed from here once we have upgraded to # transformers 4.45+ "granite": GraniteConfig, + "grok-1": Grok1Config, } for name, cls in _CONFIG_REGISTRY.items(): diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 8381c5227584e..70025ad5c6bb9 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -7,6 +7,7 @@ # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.granite import GraniteConfig +from vllm.transformers_utils.configs.grok1 import Grok1Config from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.medusa import MedusaConfig @@ -31,4 +32,5 @@ # Granite can be removed from here once we have upgraded to # transformers 4.45+ "GraniteConfig", + "Grok1Config", ] diff --git a/vllm/transformers_utils/configs/grok1.py b/vllm/transformers_utils/configs/grok1.py new file mode 100644 index 0000000000000..3eed2744318b1 --- /dev/null +++ b/vllm/transformers_utils/configs/grok1.py @@ -0,0 +1,60 @@ +from transformers.configuration_utils import PretrainedConfig + + +class Grok1Config(PretrainedConfig): + model_type = "grok-1" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__(self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=32768, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + attn_output_multiplier=1.0, + max_attn_value=1.0, + max_position_embeddings=4096, + embedding_multiplier_scale: float = 1.0, + output_multiplier_scale: float = 1.0, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=True, + num_experts_per_tok=2, + num_experts=8, + output_router_logits=False, + router_aux_loss_coef=0.001, + **kwargs): + self.vocab_size = vocab_size + self.attn_output_multiplier = attn_output_multiplier + self.max_attn_value = max_attn_value + self.max_position_embeddings = max_position_embeddings + self.embedding_multiplier_scale = embedding_multiplier_scale + self.output_multiplier_scale = output_multiplier_scale + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + 
self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + )
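
A minimal usage sketch (an illustration, not part of the patch series): once the series is applied, the newly added Grok1Config can be constructed directly. The keyword values below are simply the defaults from its __init__ signature as defined in grok1.py, not verified Grok-1 hyperparameters, and "grok-1" is the model_type key the series registers in _CONFIG_REGISTRY.

from vllm.transformers_utils.configs import Grok1Config

# Construct a config using the defaults declared in the patch; a real
# checkpoint would typically supply its own values instead.
config = Grok1Config(
    vocab_size=32000,
    hidden_size=4096,
    intermediate_size=32768,
    num_hidden_layers=32,
    num_attention_heads=32,
    num_key_value_heads=32,
    num_experts=8,
    num_experts_per_tok=2,
    rms_norm_eps=1e-5,
    tie_word_embeddings=True,
)

# model_type is the key that vllm/transformers_utils/config.py now maps
# to Grok1Config in _CONFIG_REGISTRY.
print(config.model_type)  # grok-1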