From fa23e51cec112fce860ac896072bd6732f558304 Mon Sep 17 00:00:00 2001
From: ElizaWszola
Date: Tue, 24 Sep 2024 14:50:46 -0400
Subject: [PATCH] working kernel

---
 csrc/moe/marlin_kernels/marlin_moe_kernel.h  | 31 ++++-----
 csrc/moe/marlin_moe_ops.cu                   | 13 +---
 csrc/quantization/gptq_marlin/gptq_marlin.cu |  2 +-
 tests/kernels/test_awq_marlin.py             | 66 +++++++++-----------
 4 files changed, 49 insertions(+), 63 deletions(-)

diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel.h b/csrc/moe/marlin_kernels/marlin_moe_kernel.h
index d44294262d830..808fcdae7a7c6 100644
--- a/csrc/moe/marlin_kernels/marlin_moe_kernel.h
+++ b/csrc/moe/marlin_kernels/marlin_moe_kernel.h
@@ -38,7 +38,7 @@ using FragA = Vec;
 using FragB = Vec;
 using FragC = Vec;
 using FragS = Vec;  // quantization scales
-using FragZP = Vec;
+using FragZP = Vec;
 
 // Predicated asynchronous global->shared copy; used for inputs A where we apply
 // predication to handle batchsizes that are not multiples of 16.
@@ -230,13 +230,6 @@ __device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) {
   frag_b[1] = __hsub2(frag_b[1], zp);
 }
 
-// Given 2 floats multiply by 2 scales (halves)
-__device__ inline void scale_float(float* c, FragS& s) {
-  __half* s_ptr = reinterpret_cast<__half*>(&s);
-  c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
-  c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
-}
-
 // Same as above, but for act_order (each K is multiplied individually)
 __device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2,
                               FragS& frag_s_3, FragS& frag_s_4, int i) {
@@ -252,6 +245,13 @@ __device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2,
   frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
 }
 
+// Given 2 floats multiply by 2 scales (halves)
+__device__ inline void scale_float(float* c, FragS& s) {
+  __half* s_ptr = reinterpret_cast<__half*>(&s);
+  c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
+  c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
+}
+
 // Wait until barrier reaches `count`, then lock for current threadblock.
 __device__ inline void barrier_acquire(int* lock, int count) {
   if (threadIdx.x == 0) {
@@ -440,6 +440,7 @@ __device__ void MarlinMoESingle(
           : 1;
   constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
   int s_gl_rd_delta = s_gl_stride;
+
   // Scale size/strides with act_order
   constexpr int tb_k = 16 * thread_k_blocks;
   constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
@@ -536,12 +537,6 @@ __device__ void MarlinMoESingle(
   int sh_num_groups = -1;
   constexpr int sh_max_num_groups = 32;
 
-  int shs_size;
-  if constexpr (has_act_order)
-    shs_size = sh_max_num_groups * s_sh_stride + threads;
-  else
-    shs_size = group_blocks > 0 ? stages * s_sh_stage : threads;
-
   extern __shared__ int4 sh[];
   // Shared memory storage for global fetch pipelines.
   int4* sh_a = sh;
@@ -674,6 +669,7 @@ __device__ void MarlinMoESingle(
         for (int j = 0; j < b_thread_vecs; j++) {
           cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
         }
+
         B_ptr[i] += b_gl_rd_delta_o;
       }
@@ -990,6 +986,10 @@ __device__ void MarlinMoESingle(
         FragB frag_b0 = dequant(b_quant_0);
         FragB frag_b1 = dequant(b_quant_1);
 
+        // Apply zero-point to frag_b0
+        if constexpr (has_zp) {
+          sub_zp(frag_b0, frag_zp[j], 0);
+        }
+
         // Apply scale to frag_b0
         if constexpr (has_act_order) {
@@ -1193,6 +1193,7 @@ __device__ void MarlinMoESingle(
     ((half2*)sh)[idx] = res;
   };
 
+
   if (threadIdx.x / 32 < thread_n_blocks / 4) {
   #pragma unroll
     for (int i = 0; i < thread_m_blocks; i++) {
@@ -1277,6 +1278,7 @@ __device__ void MarlinMoESingle(
   // ensure all shared memory accesses are static. Note that both pipelines
   // have even length meaning that the next iteration will always start at
   // index 0.
+
   #pragma unroll
   for (int pipe = 0; pipe < stages;) {
   #pragma unroll
@@ -1420,6 +1422,7 @@ __device__ void MarlinMoESingle(
       s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
       zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
     }
+
     start_pipes();
   }
 }
diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu
index 7f7017deac401..7052510086288 100644
--- a/csrc/moe/marlin_moe_ops.cu
+++ b/csrc/moe/marlin_moe_ops.cu
@@ -436,16 +436,9 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
   int4* C_ptr = (int4*)C;
   const float* topk_weights_ptr = (const float*)topk_weights;
   const int* sorted_ids_ptr = (const int*)sorted_ids;
-  const int4* s_ptr =
-      (const int4*)s +
-      (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) *
-       prob_n / 8) *
-          expert_idx;
+  const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
   const int4* zp_ptr =
-      (const int4*)zp +
-      (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) *
-       prob_n / 4) *
-          expert_idx;
+      (const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx;
   const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
   const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
   int* locks = (int*)workspace;
@@ -570,7 +563,7 @@ torch::Tensor marlin_gemm_moe(
                 "b_zeros dim 1 = ", b_zeros.size(1),
                 " is not num_groups = ", num_groups);
     TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
-                "b_zeros dim 2 = ", b_scales.size(2),
+                "b_zeros dim 2 = ", b_zeros.size(2),
                 " is not size_n / pack_factor = ", size_n / pack_factor);
   }
diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu
index 9b4a6a515107d..f943185bab7f0 100644
--- a/csrc/quantization/gptq_marlin/gptq_marlin.cu
+++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -2258,7 +2258,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                 "b_zeros dim 0 = ", b_zeros.size(0),
                 " is not num_groups = ", num_groups);
     TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
-                "b_zeros dim 1 = ", b_scales.size(1),
+                "b_zeros dim 1 = ", b_zeros.size(1),
                 " is not size_n / pack_factor = ", size_n / pack_factor);
   }
diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py
index e60b7d976c8f6..e408636bdb2d2 100644
--- a/tests/kernels/test_awq_marlin.py
+++ b/tests/kernels/test_awq_marlin.py
@@ -8,23 +8,24 @@
 import torch
 
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
     fused_marlin_moe, single_marlin_moe)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
-from vllm.scalar_type import scalar_types
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
-    awq_marlin_quantize
-)
+    awq_marlin_quantize)
+from vllm.scalar_type import scalar_types
 
+
 def stack_and_dev(tensors: List[torch.Tensor]):
     dev = tensors[0].device
     return torch.stack(tensors, dim=0).to(dev)
 
+
 def compute_max_diff(output, output_ref):
     return torch.mean(torch.abs(output - output_ref)) / torch.mean(
         torch.abs(output_ref))
 
+
 def torch_moe(a, w1, w2, score, topk):
     B, D = a.shape
     a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
@@ -55,7 +56,7 @@ def torch_moe_single(a, w, score, topk):
             out[mask] = a[mask] @ w[i].transpose(0, 1)
     return (out.view(B, -1, w.shape[1])).sum(dim=1)
 
-@pytest.mark.skip("TODO")
+
 @pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
 @pytest.mark.parametrize("n", [128, 2048, 256, 1024])
 @pytest.mark.parametrize("k", [128, 1024, 512])
@@ -77,8 +78,7 @@ def test_fused_marlin_moe_awq(
     if topk > e:
         return
 
-    quant_type = (scalar_types.uint4b8
-                  if num_bits == 4 else scalar_types.uint8b128)
+    quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8)
     dtype = torch.float16
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
@@ -123,15 +123,6 @@ def test_fused_marlin_moe_awq(
     score = torch.randn((m, e), device="cuda", dtype=dtype)
 
     topk_weights, topk_ids = fused_topk(a, score, topk, False)
-
-    triton_output = fused_moe(
-        a,
-        w_ref1.transpose(1, 2).contiguous(),
-        w_ref2.transpose(1, 2).contiguous(),
-        score,
-        topk,
-        renormalize=False,
-    )
     marlin_output = fused_marlin_moe(
         a,
         qweight1,
         qweight2,
         scales1,
         scales2,
         score,
         topk_weights,
         topk_ids,
         w_zeros1=zp1,
         w_zeros2=zp2,
         num_bits=num_bits,
     )
 
-    assert compute_max_diff(marlin_output, triton_output) < 4e-2
+    torch_output = torch_moe(
+        a,
+        w_ref1.transpose(1, 2),
+        w_ref2.transpose(1, 2),
+        score,
+        topk,
+    )
+
+    assert compute_max_diff(marlin_output, torch_output) < 4e-2
 
 
 # @pytest.mark.skip("This test is here for the sake of debugging, "
 #                   "don't run it in automated tests.")
-# @pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
-# @pytest.mark.parametrize("n", [128, 2048, 256, 1024])
-# @pytest.mark.parametrize("k", [128, 1024, 512])
-# @pytest.mark.parametrize("e", [4, 8, 64])
-# @pytest.mark.parametrize("topk", [2, 6])
-# @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
-# @pytest.mark.parametrize("num_bits", [4, 8])
-@pytest.mark.parametrize("m", [1])
-@pytest.mark.parametrize("n", [128])
-@pytest.mark.parametrize("k", [128])
-@pytest.mark.parametrize("e", [4])
-@pytest.mark.parametrize("topk", [2])
-@pytest.mark.parametrize("group_size", [-1])
-@pytest.mark.parametrize("num_bits", [4])
+@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
+@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
+@pytest.mark.parametrize("k", [128, 1024, 512])
+@pytest.mark.parametrize("e", [4, 8, 64])
+@pytest.mark.parametrize("topk", [2, 6])
+@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
+@pytest.mark.parametrize("num_bits", [4, 8])
 def test_single_marlin_moe_multiply_awq(
     m: int,
     n: int,
@@ -174,11 +166,12 @@ def test_single_marlin_moe_multiply_awq(
     group_size: int,
     num_bits: int,
 ):
+    torch.manual_seed(7)
+
     if topk > e:
         return
 
-    quant_type = (scalar_types.uint4b8
-                  if num_bits == 4 else scalar_types.uint8b128)
+    quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8)
     dtype = torch.float16
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
@@ -198,11 +191,8 @@ def test_single_marlin_moe_multiply_awq(
     w_ref = stack_and_dev(w_ref_l)
     qweight = stack_and_dev(qweights_l).contiguous()
-    scales = stack_and_dev(scales_l)
-    zp = stack_and_dev(zp_l)
-
-    print(scales.dtype)
-    print(zp.dtype)
+    scales = stack_and_dev(scales_l).contiguous()
+    zp = stack_and_dev(zp_l).contiguous()
 
     score = torch.randn((m, e), device="cuda", dtype=dtype)
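
Note (not part of the patch): the kernel hunk above adds the sub_zp() step so that AWQ zero points are subtracted from the dequantized fragments before the scale helpers (scale4 / scale_float) are applied, and the updated test builds its reference from the dequantized w_ref tensors returned by awq_marlin_quantize. The sketch below is a minimal NumPy illustration of that per-group "subtract zero point, then scale" dequantization; the function name, shapes, and group layout here are illustrative assumptions, not vLLM or Marlin APIs.

    import numpy as np

    def dequant_awq_group(q, scales, zeros, group_size):
        """Reference AWQ-style dequantization: w = (q - zero) * scale per group.

        q:      (k, n) integer weight codes
        scales: (k // group_size, n) per-group scales
        zeros:  (k // group_size, n) per-group zero points
        """
        k, n = q.shape
        w = np.empty((k, n), dtype=np.float16)
        for g in range(k // group_size):
            rows = slice(g * group_size, (g + 1) * group_size)
            # Subtract the group's zero point first, then apply the group's
            # scale, mirroring sub_zp() followed by the scale helpers that the
            # kernel runs on half2 fragments after dequant().
            w[rows] = ((q[rows] - zeros[g]) * scales[g]).astype(np.float16)
        return w

Because awq_marlin_quantize already returns these dequantized weights as w_ref, the updated test can compare fused_marlin_moe directly against torch_moe over w_ref1/w_ref2 instead of the previously used Triton fused_moe path.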