Post-merge fix, remove e=4 case from unit tests to speed them up a bit
ElizaWszola committed Sep 30, 2024
1 parent 6c4eca2 commit 091a4bb
Showing 3 changed files with 13 additions and 21 deletions.
tests/kernels/test_awq_marlin.py (2 additions, 8 deletions)
@@ -18,7 +18,7 @@
 @pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
 @pytest.mark.parametrize("n", [128, 2048, 256, 1024])
 @pytest.mark.parametrize("k", [128, 1024, 512])
-@pytest.mark.parametrize("e", [4, 8, 64])
+@pytest.mark.parametrize("e", [8, 64])
 @pytest.mark.parametrize("topk", [2, 6])
 @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
 @pytest.mark.parametrize("num_bits", [4, 8])
@@ -33,9 +33,6 @@ def test_fused_marlin_moe_awq(
 ):
     torch.manual_seed(7)
 
-    if topk > e:
-        return
-
     quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8)
     dtype = torch.float16
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
@@ -112,7 +109,7 @@ def test_fused_marlin_moe_awq(
 @pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
 @pytest.mark.parametrize("n", [128, 2048, 256, 1024])
 @pytest.mark.parametrize("k", [128, 1024, 512])
-@pytest.mark.parametrize("e", [4, 8, 64])
+@pytest.mark.parametrize("e", [8, 64])
 @pytest.mark.parametrize("topk", [2, 6])
 @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
 @pytest.mark.parametrize("num_bits", [4, 8])
@@ -127,9 +124,6 @@ def test_single_marlin_moe_multiply_awq(
 ):
     torch.manual_seed(7)
 
-    if topk > e:
-        return
-
     quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8)
     dtype = torch.float16
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
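
With e = 4 gone from the parameter grid, the early-return guard deleted above is dead code: every remaining pairing of e in {8, 64} with topk in {2, 6} already satisfies topk <= e, so no combination needs to be skipped at runtime. A standalone sanity check of that claim (not part of the diff):

    # Every (e, topk) pair in the pruned grid keeps topk <= e, so the
    # old "if topk > e: return" guard can no longer fire.
    e_values = [8, 64]      # e = 4 removed by this commit
    topk_values = [2, 6]
    assert all(topk <= e for e in e_values for topk in topk_values)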
tests/kernels/test_moe.py (6 additions, 9 deletions)
@@ -97,7 +97,7 @@ def test_mixtral_moe(dtype: torch.dtype):
 @pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
 @pytest.mark.parametrize("n", [128, 2048, 256, 1024])
 @pytest.mark.parametrize("k", [128, 1024, 512])
-@pytest.mark.parametrize("e", [4, 8, 64])
+@pytest.mark.parametrize("e", [8, 64])
 @pytest.mark.parametrize("topk", [2, 6])
 @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
 @pytest.mark.parametrize("act_order", [True, False])
@@ -116,9 +116,6 @@ def test_fused_marlin_moe(
 ):
     seed_everything(7)
 
-    if topk > e:
-        return
-
     # Filter act_order
     if act_order:
         if group_size == -1:
@@ -237,18 +234,20 @@ def test_fused_marlin_moe(
                             device="cuda",
                             requires_grad=False)
 
+    zp = torch.empty((0), dtype=dtype, device="cuda", requires_grad=False)
+
     opcheck(torch.ops._moe_C.marlin_gemm_moe,
             (a, qweight1, sorted_token_ids, topk_weights, topk_ids,
-             scales1, g_idx1, sort_indices1, workspace, quant_type, m,
-             2 * n, k, True, e, topk, block_size_m, True, False))
+             scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m,
+             2 * n, k, True, False, e, topk, block_size_m, True, False))
 
 
 @pytest.mark.skip("This test is here for the sake of debugging, "
                   "don't run it in automated tests.")
 @pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
 @pytest.mark.parametrize("n", [128, 2048, 256, 1024])
 @pytest.mark.parametrize("k", [128, 1024, 512])
-@pytest.mark.parametrize("e", [4, 8, 64])
+@pytest.mark.parametrize("e", [8, 64])
 @pytest.mark.parametrize("topk", [2, 6])
 @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
 @pytest.mark.parametrize("act_order", [True, False])
@@ -265,8 +264,6 @@ def test_single_marlin_moe_multiply(
     num_bits: int,
     is_k_full: bool,
 ):
-    if topk > e:
-        return
-
     # Filter act_order
     if act_order:
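
The new zp argument mirrors the reordered kernel schema shown in vllm/_custom_ops.py below: b_zero_points slots in directly after b_scales, and a has_zero_point flag (the extra False before e in the opcheck tuple) follows is_k_full. Since the signature types b_zero_points as torch.Tensor rather than an Optional, a path without zero points passes an empty placeholder instead of None. A minimal standalone sketch of that convention (assuming a CUDA device, as in the tests):

    import torch

    dtype = torch.float16
    # Empty placeholder for unused zero points: the op requires a Tensor
    # argument, so a quantization path without zero points passes a
    # zero-element tensor and sets has_zero_point=False.
    zp = torch.empty((0, ), dtype=dtype, device="cuda", requires_grad=False)
    has_zero_point = zp.numel() > 0  # evaluates to False here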
vllm/_custom_ops.py (5 additions, 4 deletions)
@@ -819,10 +819,11 @@ def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
                          sorted_ids: torch.Tensor,
                          topk_weights: torch.Tensor,
                          topk_ids: torch.Tensor, b_scales: torch.Tensor,
-                         g_idx: torch.Tensor, perm: torch.Tensor,
-                         workspace: torch.Tensor, b_q_type: ScalarType,
-                         size_m: int, size_n: int, size_k: int,
-                         is_k_full: bool, num_experts: int, topk: int,
+                         b_zero_points: torch.Tensor, g_idx: torch.Tensor,
+                         perm: torch.Tensor, workspace: torch.Tensor,
+                         b_q_type: ScalarType, size_m: int, size_n: int,
+                         size_k: int, is_k_full: bool,
+                         has_zero_point: bool, num_experts: int, topk: int,
                          moe_block_size: int, replicate_input: bool,
                          apply_weights: bool) -> torch.Tensor:
     return torch.empty((size_m, topk, size_n),
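
marlin_gemm_moe_fake is the shape-only fake (meta) counterpart of the marlin_gemm_moe op: it allocates an output of shape (size_m, topk, size_n) without running the CUDA kernel, and opcheck compares the two, which is why its parameter list has to change in lockstep with the real schema. A minimal sketch of the same real-op/fake-op/opcheck pattern with a toy op (demo::row_sum is a hypothetical name, not a vLLM op; torch.library.custom_op and opcheck assume PyTorch 2.4+):

    import torch

    # Toy custom op with a real implementation...
    @torch.library.custom_op("demo::row_sum", mutates_args=())
    def row_sum(x: torch.Tensor) -> torch.Tensor:
        return x.sum(dim=-1)

    # ...and a fake implementation that, like marlin_gemm_moe_fake, only
    # allocates an output of the right shape, dtype, and device.
    @row_sum.register_fake
    def _(x: torch.Tensor) -> torch.Tensor:
        return x.new_empty(x.shape[:-1])

    # opcheck verifies that the fake matches the real kernel.
    torch.library.opcheck(row_sum, (torch.randn(4, 8), ))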
