
Commit

working kernel
ElizaWszola committed Sep 24, 2024
1 parent 98ec9b6 commit fa23e51
Showing 4 changed files with 49 additions and 63 deletions.
31 changes: 17 additions & 14 deletions csrc/moe/marlin_kernels/marlin_moe_kernel.h
@@ -38,7 +38,7 @@ using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales
using FragZP = Vec<half2, 1>;
using FragZP = Vec<half2, 4>;

// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
@@ -230,13 +230,6 @@ __device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) {
frag_b[1] = __hsub2(frag_b[1], zp);
}

// Given 2 floats multiply by 2 scales (halves)
__device__ inline void scale_float(float* c, FragS& s) {
__half* s_ptr = reinterpret_cast<__half*>(&s);
c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
}

// Same as above, but for act_order (each K is multiplied individually)
__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2,
FragS& frag_s_3, FragS& frag_s_4, int i) {
@@ -252,6 +245,13 @@ __device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2,
frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
}

// Given 2 floats multiply by 2 scales (halves)
__device__ inline void scale_float(float* c, FragS& s) {
__half* s_ptr = reinterpret_cast<__half*>(&s);
c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
}

// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
@@ -440,6 +440,7 @@ __device__ void MarlinMoESingle(
: 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;

// Scale size/strides with act_order
constexpr int tb_k = 16 * thread_k_blocks;
constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
@@ -536,12 +537,6 @@ __device__ void MarlinMoESingle(
int sh_num_groups = -1;
constexpr int sh_max_num_groups = 32;

int shs_size;
if constexpr (has_act_order)
shs_size = sh_max_num_groups * s_sh_stride + threads;
else
shs_size = group_blocks > 0 ? stages * s_sh_stage : threads;

extern __shared__ int4 sh[];
// Shared memory storage for global fetch pipelines.
int4* sh_a = sh;
@@ -674,6 +669,7 @@ __device__ void MarlinMoESingle(
for (int j = 0; j < b_thread_vecs; j++) {
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
}

B_ptr[i] += b_gl_rd_delta_o;
}

@@ -990,6 +986,10 @@ __device__ void MarlinMoESingle(

FragB frag_b0 = dequant<w_type_id>(b_quant_0);
FragB frag_b1 = dequant<w_type_id>(b_quant_1);
// Apply zero-point to frag_b0
if constexpr (has_zp) {
sub_zp(frag_b0, frag_zp[j], 0);
}

// Apply scale to frag_b0
if constexpr (has_act_order) {
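(Illustration, not part of the diff.) The order in this inner loop is: dequantize the packed fragment, subtract the per-group zero point when has_zp is set, then apply the scale. A minimal Python sketch of that per-element arithmetic, assuming the AWQ convention of an unsigned 4-bit value with an explicit zero point versus the GPTQ-style uint4b8 type whose bias of 8 is fixed; all numbers below are made up:

# Sketch only: per-element dequantization, mirroring the order above.
def dequant_awq(q: int, zp: int, s: float) -> float:
    # unsigned 4-bit value with an explicit per-group zero point
    return (q - zp) * s

def dequant_uint4b8(q: int, s: float) -> float:
    # GPTQ-style type: a fixed bias of 8 stands in for the zero point
    return (q - 8) * s

print(dequant_awq(q=11, zp=3, s=0.05))   # ~0.40
print(dequant_uint4b8(q=11, s=0.05))     # ~0.15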
@@ -1193,6 +1193,7 @@ __device__ void MarlinMoESingle(

((half2*)sh)[idx] = res;
};

if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
@@ -1277,6 +1278,7 @@ __device__ void MarlinMoESingle(
// ensure all shared memory accesses are static. Note that both pipelines
// have even length meaning that the next iteration will always start at
// index 0.

#pragma unroll
for (int pipe = 0; pipe < stages;) {
#pragma unroll
@@ -1420,6 +1422,7 @@ __device__ void MarlinMoESingle(
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
}

start_pipes();
}
}
13 changes: 3 additions & 10 deletions csrc/moe/marlin_moe_ops.cu
@@ -436,16 +436,9 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
int4* C_ptr = (int4*)C;
const float* topk_weights_ptr = (const float*)topk_weights;
const int* sorted_ids_ptr = (const int*)sorted_ids;
const int4* s_ptr =
(const int4*)s +
(((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) *
prob_n / 8) *
expert_idx;
const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
const int4* zp_ptr =
(const int4*)zp +
(((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) *
prob_n / 4) *
expert_idx;
(const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx;
const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
int* locks = (int*)workspace;
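(Illustration, not part of the diff.) The rewritten offsets index each expert's scale and zero-point block via num_groups directly instead of re-deriving the group count from group_size. A hedged Python sketch of the int4-element arithmetic, assuming half-precision scales (8 per 16-byte int4) and 4-bit zero points (pack_factor = 8, so pack_factor * 4 = 32 per int4); the sizes are illustrative:

# Sketch only: per-expert offsets measured in int4 (16-byte) elements.
num_groups, prob_n, expert_idx = 16, 4096, 3
pack_factor = 8  # 4-bit values per 32-bit word

s_offset = num_groups * prob_n // 8 * expert_idx                    # fp16 scales
zp_offset = num_groups * prob_n // (pack_factor * 4) * expert_idx   # packed zero points

print(s_offset, zp_offset)  # 24576 6144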
@@ -570,7 +563,7 @@ torch::Tensor marlin_gemm_moe(
"b_zeros dim 1 = ", b_zeros.size(1),
" is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
"b_zeros dim 2 = ", b_scales.size(2),
"b_zeros dim 2 = ", b_zeros.size(2),
" is not size_n / pack_factor = ", size_n / pack_factor);
}

2 changes: 1 addition & 1 deletion csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -2258,7 +2258,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
"b_zeros dim 0 = ", b_zeros.size(0),
" is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
"b_zeros dim 1 = ", b_scales.size(1),
"b_zeros dim 1 = ", b_zeros.size(1),
" is not size_n / pack_factor = ", size_n / pack_factor);
}

66 changes: 28 additions & 38 deletions tests/kernels/test_awq_marlin.py
@@ -8,23 +8,24 @@
import torch

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.scalar_type import scalar_types
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
awq_marlin_quantize
)
awq_marlin_quantize)
from vllm.scalar_type import scalar_types


def stack_and_dev(tensors: List[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)


def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))


def torch_moe(a, w1, w2, score, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
@@ -55,7 +56,7 @@ def torch_moe_single(a, w, score, topk):
out[mask] = a[mask] @ w[i].transpose(0, 1)
return (out.view(B, -1, w.shape[1])).sum(dim=1)

@pytest.mark.skip("TODO")

@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@@ -77,8 +78,7 @@ def test_fused_marlin_moe_awq(
if topk > e:
return

quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8)
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
@@ -123,15 +123,6 @@ def test_fused_marlin_moe_awq(
score = torch.randn((m, e), device="cuda", dtype=dtype)

topk_weights, topk_ids = fused_topk(a, score, topk, False)

triton_output = fused_moe(
a,
w_ref1.transpose(1, 2).contiguous(),
w_ref2.transpose(1, 2).contiguous(),
score,
topk,
renormalize=False,
)
marlin_output = fused_marlin_moe(
a,
qweight1,
@@ -146,25 +137,26 @@ def test_fused_marlin_moe_awq(
num_bits=num_bits,
)

assert compute_max_diff(marlin_output, triton_output) < 4e-2
torch_output = torch_moe(
a,
w_ref1.transpose(1, 2),
w_ref2.transpose(1, 2),
score,
topk,
)

assert compute_max_diff(marlin_output, torch_output) < 4e-2
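(Illustration, not part of the diff.) compute_max_diff, defined near the top of this file, is a mean relative error, so the 4e-2 bound tolerates roughly 4% average deviation from the torch_moe reference. A small sketch of how the metric behaves:

import torch

def compute_max_diff(output, output_ref):
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref))

ref = torch.ones(16)
out = ref * 1.03                   # uniform 3% deviation
print(compute_max_diff(out, ref))  # ~0.03, inside the 4e-2 bound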


# @pytest.mark.skip("This test is here for the sake of debugging, "
# "don't run it in automated tests.")
# @pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
# @pytest.mark.parametrize("n", [128, 2048, 256, 1024])
# @pytest.mark.parametrize("k", [128, 1024, 512])
# @pytest.mark.parametrize("e", [4, 8, 64])
# @pytest.mark.parametrize("topk", [2, 6])
# @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
# @pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("m", [1])
@pytest.mark.parametrize("n", [128])
@pytest.mark.parametrize("k", [128])
@pytest.mark.parametrize("e", [4])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("num_bits", [4])
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("num_bits", [4, 8])
def test_single_marlin_moe_multiply_awq(
m: int,
n: int,
@@ -174,11 +166,12 @@ def test_single_marlin_moe_multiply_awq(
group_size: int,
num_bits: int,
):
torch.manual_seed(7)

if topk > e:
return

quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8)
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
@@ -198,11 +191,8 @@ def test_single_marlin_moe_multiply_awq(

w_ref = stack_and_dev(w_ref_l)
qweight = stack_and_dev(qweights_l).contiguous()
scales = stack_and_dev(scales_l)
zp = stack_and_dev(zp_l)

print(scales.dtype)
print(zp.dtype)
scales = stack_and_dev(scales_l).contiguous()
zp = stack_and_dev(zp_l).contiguous()

score = torch.randn((m, e), device="cuda", dtype=dtype)

