Merge branch 'marlin-moe-zero-points' of https://github.com/neuralmagic/vllm into marlin-moe-zero-points
ElizaWszola committed Oct 2, 2024
2 parents fa4d269 + a966417 commit 91924c1
Showing 15 changed files with 292 additions and 120 deletions.
2 changes: 0 additions & 2 deletions CMakeLists.txt
@@ -332,8 +332,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu"
"csrc/moe/marlin_moe_ops.cu")
endif()

31 changes: 0 additions & 31 deletions csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu

This file was deleted.

20 changes: 0 additions & 20 deletions csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h

This file was deleted.

16 changes: 8 additions & 8 deletions csrc/moe/marlin_moe_ops.cu
@@ -30,7 +30,6 @@
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku4.h"
#include "marlin_kernels/marlin_moe_kernel_ku8.h"

template <typename T>
inline std::string str(T x) {
@@ -158,6 +157,7 @@ thread_config_t small_batch_thread_configs[] = {
{128, 64, 128}, // Reduce N 2X, same K
{64, 256, 256}, // Reduce K 2X, increase N 2X
{64, 128, 128}, // Reduce K 2X, same N
+{64, 64, 128}, // Reduce both 2X
};

thread_config_t large_batch_thread_configs[] = {
@@ -168,6 +168,7 @@ thread_config_t large_batch_thread_configs[] = {
{128, 128, 256}, // Reduce N 2X, increase K 2X
{64, 128, 128}, // Reduce N 2X, same K
{128, 64, 128}, // Reduce N 4X, increase K 2X
+{64, 64, 128}, // Reduce N 4X, same K
};

int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
@@ -461,7 +462,6 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4)
-CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8)
else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
str(prob_n) + ", " + str(prob_k) + "]" +
@@ -484,13 +484,13 @@ torch::Tensor marlin_gemm_moe(
torch::Tensor& b_zeros, const torch::Tensor& g_idx,
const torch::Tensor& perm, torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n,
-int64_t size_k, bool is_k_full, bool has_zp, int64_t num_experts,
-int64_t topk, int64_t moe_block_size, bool replicate_input,
-bool apply_weights) {
+int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
+int64_t moe_block_size, bool replicate_input, bool apply_weights) {
+bool has_zp = b_zeros.size(1) != 0;
if (has_zp) {
-TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8,
-            "b_q_type must be u4 or u8 when has_zp = True. Got = ",
-            b_q_type->str());
+TORCH_CHECK(
+    *b_q_type == vllm::kU4,
+    "b_q_type must be u4 when has_zp = True. Got = ", b_q_type->str());
} else {
TORCH_CHECK(
*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
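The key interface change above is that marlin_gemm_moe no longer takes an explicit has_zp flag; zero-point mode is now inferred from b_zeros itself (b_zeros.size(1) != 0). A minimal Python sketch of that check, with illustrative tensor shapes (the (0, 0) empty tensor mirrors what tests/kernels/test_moe.py now passes when there are no zero points):

import torch

def has_zero_points(b_zeros: torch.Tensor) -> bool:
    # Python mirror of the new C++ check `b_zeros.size(1) != 0`.
    return b_zeros.size(1) != 0

# AWQ-style call site: a non-empty zero-point tensor (shape chosen for illustration only).
zp_awq = torch.zeros((8, 32, 256), dtype=torch.float16)

# GPTQ-style call site: no zero points, signalled by an empty (0, 0) tensor,
# as in the updated tests/kernels/test_moe.py.
zp_none = torch.empty((0, 0), dtype=torch.float16)

assert has_zero_points(zp_awq)
assert not has_zero_points(zp_none)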
5 changes: 2 additions & 3 deletions csrc/moe/marlin_moe_ops.h
@@ -11,6 +11,5 @@ torch::Tensor marlin_gemm_moe(
torch::Tensor& b_zeros, const torch::Tensor& g_idx,
const torch::Tensor& perm, torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n,
-int64_t size_k, bool is_k_full, bool has_zp, int64_t num_experts,
-int64_t topk, int64_t moe_block_size, bool replicate_input,
-bool apply_weights);
+int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
+int64_t moe_block_size, bool replicate_input, bool apply_weights);
5 changes: 2 additions & 3 deletions csrc/moe/torch_bindings.cpp
@@ -15,9 +15,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
"int size_n, int size_k, bool is_k_full, bool has_zp, int num_experts, "
"int topk, int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor");
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, int "
"moe_block_size, bool replicate_input, bool apply_weights) -> Tensor");
m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
#endif
}
12 changes: 4 additions & 8 deletions tests/kernels/test_awq_marlin.py
@@ -21,19 +21,18 @@
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("num_bits", [4, 8])
def test_fused_marlin_moe_awq(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
-num_bits: int,
):
torch.manual_seed(7)

-quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8)
+num_bits = 4
+quant_type = scalar_types.uint4
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
@@ -87,7 +86,6 @@ def test_fused_marlin_moe_awq(
score,
topk_weights,
topk_ids,
-has_zero_point=True,
w1_zeros=zp1,
w2_zeros=zp2,
num_bits=num_bits,
@@ -112,19 +110,18 @@
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("num_bits", [4, 8])
def test_single_marlin_moe_multiply_awq(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
-num_bits: int,
):
torch.manual_seed(7)

-quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8)
+num_bits = 4
+quant_type = scalar_types.uint4
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
@@ -155,7 +152,6 @@ def test_single_marlin_moe_multiply_awq(
score,
topk,
renormalize=False,
-has_zero_point=True,
w_zeros=zp,
num_bits=num_bits)

8 changes: 5 additions & 3 deletions tests/kernels/test_moe.py
@@ -234,12 +234,14 @@ def test_fused_marlin_moe(
device="cuda",
requires_grad=False)

-zp = torch.empty((0), dtype=dtype, device="cuda", requires_grad=False)
+zp = torch.empty((0, 0),
+                 dtype=dtype,
+                 device="cuda",
+                 requires_grad=False)

opcheck(torch.ops._moe_C.marlin_gemm_moe,
(a, qweight1, sorted_token_ids, topk_weights, topk_ids,
scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m,
-2 * n, k, True, False, e, topk, block_size_m, True, False))
+2 * n, k, True, e, topk, block_size_m, True, False))


@pytest.mark.skip("This test is here for the sake of debugging, "
1 change: 1 addition & 0 deletions tests/weight_loading/models-large.txt
@@ -3,3 +3,4 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantize
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
+awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main
15 changes: 14 additions & 1 deletion tests/weight_loading/run_model_weight_loading_test.sh
@@ -1,7 +1,20 @@
#!/bin/bash
SUCCESS=0

-IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "weight_loading/models.txt"
+while getopts "c:" OPT; do
+  case ${OPT} in
+    c )
+      CONFIG="$OPTARG"
+      ;;
+    \? )
+      usage
+      exit 1
+      ;;
+  esac
+done
+
+
+IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < $CONFIG

for MODEL_CONFIG in "${MODEL_CONFIGS[@]}"
do
20 changes: 17 additions & 3 deletions vllm/_custom_ops.py
@@ -559,6 +559,20 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
return output


+def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
+                          size_k: int, size_n: int,
+                          num_bits: int) -> torch.Tensor:
+    num_experts = b_q_weight.shape[0]
+    assert size_k % 16 == 0
+    output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
+                         device=b_q_weight.device,
+                         dtype=b_q_weight.dtype)
+    for e in range(num_experts):
+        output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k,
+                                                    size_n, num_bits)
+    return output


def gptq_marlin_gemm(a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
@@ -822,9 +836,9 @@ def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
b_zero_points: torch.Tensor, g_idx: torch.Tensor,
perm: torch.Tensor, workspace: torch.Tensor,
b_q_type: ScalarType, size_m: int, size_n: int,
-size_k: int, is_k_full: bool,
-has_zero_point: bool, num_experts: int, topk: int,
-moe_block_size: int, replicate_input: bool,
+size_k: int, is_k_full: bool, num_experts: int,
+topk: int, moe_block_size: int,
+replicate_input: bool,
apply_weights: bool) -> torch.Tensor:
return torch.empty((size_m, topk, size_n),
dtype=a.dtype,
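For context, a usage sketch of the new awq_marlin_moe_repack helper (not part of the diff). It assumes a CUDA build of vLLM with the _C extension, illustrative sizes, and the usual AWQ int32 packing of 32 // num_bits values per word along N for each expert's weight matrix — these are assumptions, not taken from the commit:

import torch
from vllm import _custom_ops as ops

# Hypothetical problem sizes, 4-bit AWQ.
num_experts, size_k, size_n, num_bits = 8, 4096, 2048, 4
pack_factor = 32 // num_bits  # assumed AWQ packing: eight 4-bit values per int32

# Per-expert AWQ-packed weights; layout (num_experts, size_k, size_n // pack_factor) is an assumption.
qweight = torch.zeros((num_experts, size_k, size_n // pack_factor),
                      dtype=torch.int32,
                      device="cuda")

# perm is not used inside the helper shown above; an empty tensor suffices here.
perm = torch.empty((num_experts, 0), dtype=torch.int32, device="cuda")

marlin_qweight = ops.awq_marlin_moe_repack(qweight, perm, size_k, size_n, num_bits)

# Shape follows the allocation in the helper:
# (num_experts, size_k // 16, size_n * (num_bits // 2)) -> (8, 256, 4096)
print(marlin_qweight.shape)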