diff --git a/CMakeLists.txt b/CMakeLists.txt
index df22ce47e54bf..8c66c31aa6ce4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -332,8 +332,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
       "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
       "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h"
       "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu"
-      "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h"
-      "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu"
       "csrc/moe/marlin_moe_ops.cu")

 endif()
diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu
deleted file mode 100644
index 7abbc45440bfc..0000000000000
--- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu
+++ /dev/null
@@ -1,31 +0,0 @@
-#include "marlin_moe_kernel_ku8.h"
-
-namespace marlin_moe {
-
-// We return bool so we can create these different kernel calls as a sequence
-// of if-elseif's.
-bool call_marlin_moe_kernel_ku8(
-    vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
-    bool has_act_order, int group_blocks, int num_threads, int blocks,
-    int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
-    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
-    const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
-    const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
-    int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
-    int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
-    int m_block, int max_par, int cfg_max_m_blocks) {
-  bool has_zp = true;
-
-  if (false) {
-  }
-  AWQ_CALL_IF_MOE(vllm::kU8, 16, 4, 256)
-  AWQ_CALL_IF_MOE(vllm::kU8, 8, 8, 256)
-  AWQ_CALL_IF_MOE(vllm::kU8, 8, 4, 128)
-  AWQ_CALL_IF_MOE(vllm::kU8, 4, 8, 128)
-  else {
-    return false;
-  }
-  return true;
-}
-
-}  // namespace marlin_moe
diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h
deleted file mode 100644
index 03a0132aa347c..0000000000000
--- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h
+++ /dev/null
@@ -1,20 +0,0 @@
-#pragma once
-
-#include "marlin_moe_kernel.h"
-
-namespace marlin_moe {
-
-// We return bool so we can create these different kernel calls as a sequence
-// of if-elseif's.
-bool call_marlin_moe_kernel_ku8(
-    vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
-    bool has_act_order, int group_blocks, int num_threads, int blocks,
-    int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
-    const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
-    const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
-    const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
-    int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
-    int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
-    int m_block, int max_par, int cfg_max_m_blocks);
-
-}  // namespace marlin_moe
diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu
index aec22281e1aff..ce6af7fa01860 100644
--- a/csrc/moe/marlin_moe_ops.cu
+++ b/csrc/moe/marlin_moe_ops.cu
@@ -30,7 +30,6 @@
 #include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
 #include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
 #include "marlin_kernels/marlin_moe_kernel_ku4.h"
-#include "marlin_kernels/marlin_moe_kernel_ku8.h"

 template <typename T>
 inline std::string str(T x) {
@@ -158,6 +157,7 @@ thread_config_t small_batch_thread_configs[] = {
     {128, 64, 128},   // Reduce N 2X, same K
     {64, 256, 256},   // Reduce K 2X, increase N 2X
     {64, 128, 128},   // Reduce K 2X, same N
+    {64, 64, 128},    // Reduce both 2X
 };

 thread_config_t large_batch_thread_configs[] = {
@@ -168,6 +168,7 @@
     {128, 128, 256},  // Reduce N 2X, increase K 2X
     {64, 128, 128},   // Reduce N 2X, same K
     {128, 64, 128},   // Reduce N 4X, increase K 2X
+    {64, 64, 128},    // Reduce N 4X, same K
 };

 int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
@@ -461,7 +462,6 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
         CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
         CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
         CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4)
-        CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8)
         else {
           TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) +
                                  ", " + str(prob_n) + ", " + str(prob_k) + "]" +
@@ -484,13 +484,13 @@ torch::Tensor marlin_gemm_moe(
     torch::Tensor& b_zeros, const torch::Tensor& g_idx,
     const torch::Tensor& perm, torch::Tensor& workspace,
     vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n,
-    int64_t size_k, bool is_k_full, bool has_zp, int64_t num_experts,
-    int64_t topk, int64_t moe_block_size, bool replicate_input,
-    bool apply_weights) {
+    int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
+    int64_t moe_block_size, bool replicate_input, bool apply_weights) {
+  bool has_zp = b_zeros.size(1) != 0;
   if (has_zp) {
-    TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8,
-                "b_q_type must be u4 or u8 when has_zp = True. Got = ",
-                b_q_type->str());
+    TORCH_CHECK(
+        *b_q_type == vllm::kU4,
+        "b_q_type must be u4 when has_zp = True. Got = ", b_q_type->str());
   } else {
     TORCH_CHECK(
         *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h
index 0a54d93cedebc..0013787a623de 100644
--- a/csrc/moe/marlin_moe_ops.h
+++ b/csrc/moe/marlin_moe_ops.h
@@ -11,6 +11,5 @@ torch::Tensor marlin_gemm_moe(
     torch::Tensor& b_zeros, const torch::Tensor& g_idx,
     const torch::Tensor& perm, torch::Tensor& workspace,
     vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n,
-    int64_t size_k, bool is_k_full, bool has_zp, int64_t num_experts,
-    int64_t topk, int64_t moe_block_size, bool replicate_input,
-    bool apply_weights);
+    int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
+    int64_t moe_block_size, bool replicate_input, bool apply_weights);
diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp
index 85098df34b2d0..576305d48ae47 100644
--- a/csrc/moe/torch_bindings.cpp
+++ b/csrc/moe/torch_bindings.cpp
@@ -15,9 +15,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
       "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
       "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, "
-      "int size_n, int size_k, bool is_k_full, bool has_zp, int num_experts, "
-      "int topk, int moe_block_size, bool replicate_input, bool apply_weights)"
-      " -> Tensor");
+      "int size_n, int size_k, bool is_k_full, int num_experts, int topk, int "
+      "moe_block_size, bool replicate_input, bool apply_weights) -> Tensor");
   m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
 #endif
 }
diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py
index 338f46cbe09fb..0738ea9b97edb 100644
--- a/tests/kernels/test_awq_marlin.py
+++ b/tests/kernels/test_awq_marlin.py
@@ -21,7 +21,6 @@
 @pytest.mark.parametrize("e", [8, 64])
 @pytest.mark.parametrize("topk", [2, 6])
 @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
-@pytest.mark.parametrize("num_bits", [4, 8])
 def test_fused_marlin_moe_awq(
     m: int,
     n: int,
@@ -29,11 +28,11 @@ def test_fused_marlin_moe_awq(
     e: int,
     topk: int,
     group_size: int,
-    num_bits: int,
 ):
     torch.manual_seed(7)

-    quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8)
+    num_bits = 4
+    quant_type = scalar_types.uint4
     dtype = torch.float16
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
@@ -87,7 +86,6 @@ def test_fused_marlin_moe_awq(
         score,
         topk_weights,
         topk_ids,
-        has_zero_point=True,
         w1_zeros=zp1,
         w2_zeros=zp2,
         num_bits=num_bits,
@@ -112,7 +110,6 @@
 @pytest.mark.parametrize("e", [8, 64])
 @pytest.mark.parametrize("topk", [2, 6])
 @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
-@pytest.mark.parametrize("num_bits", [4, 8])
 def test_single_marlin_moe_multiply_awq(
     m: int,
     n: int,
@@ -120,11 +117,11 @@ def test_single_marlin_moe_multiply_awq(
     e: int,
     topk: int,
     group_size: int,
-    num_bits: int,
 ):
     torch.manual_seed(7)

-    quant_type = (scalar_types.uint4 if num_bits == 4 else scalar_types.uint8)
+    num_bits = 4
+    quant_type = scalar_types.uint4
     dtype = torch.float16
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
@@ -155,7 +152,6 @@ def test_single_marlin_moe_multiply_awq(
         score,
         topk,
         renormalize=False,
-        has_zero_point=True,
         w_zeros=zp,
         num_bits=num_bits)

diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py
index 360ef1330bd69..b73c45b9cd198 100644
--- a/tests/kernels/test_moe.py
+++ b/tests/kernels/test_moe.py
@@ -234,12 +234,14 @@ def test_fused_marlin_moe(
                             device="cuda",
                             requires_grad=False)

-    zp = torch.empty((0), dtype=dtype, device="cuda", requires_grad=False)
-
+    zp = torch.empty((0, 0),
+                     dtype=dtype,
+                     device="cuda",
+                     requires_grad=False)
     opcheck(torch.ops._moe_C.marlin_gemm_moe,
             (a, qweight1, sorted_token_ids, topk_weights, topk_ids, scales1,
              zp, g_idx1, sort_indices1, workspace, quant_type, m,
-             2 * n, k, True, False, e, topk, block_size_m, True, False))
+             2 * n, k, True, e, topk, block_size_m, True, False))


 @pytest.mark.skip("This test is here for the sake of debugging, "
diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt
index 3e6eba04f1a87..5fda910fde084 100644
--- a/tests/weight_loading/models-large.txt
+++ b/tests/weight_loading/models-large.txt
@@ -3,3 +3,4 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
 compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
 compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
 gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
+awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main
\ No newline at end of file
diff --git a/tests/weight_loading/run_model_weight_loading_test.sh b/tests/weight_loading/run_model_weight_loading_test.sh
index 0cb45d1780c2c..e80c1d6c5849c 100755
--- a/tests/weight_loading/run_model_weight_loading_test.sh
+++ b/tests/weight_loading/run_model_weight_loading_test.sh
@@ -1,7 +1,20 @@
 #!/bin/bash
 SUCCESS=0

-IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "weight_loading/models.txt"
+while getopts "c:" OPT; do
+  case ${OPT} in
+    c )
+      CONFIG="$OPTARG"
+      ;;
+    \? )
+      usage
+      exit 1
+      ;;
+  esac
+done
+
+
+IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < $CONFIG

 for MODEL_CONFIG in "${MODEL_CONFIGS[@]}"
 do
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index bc7b4293c119e..69d468881db30 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -559,6 +559,20 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
     return output


+def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
+                          size_k: int, size_n: int,
+                          num_bits: int) -> torch.Tensor:
+    num_experts = b_q_weight.shape[0]
+    assert size_k % 16 == 0
+    output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)),
+                         device=b_q_weight.device,
+                         dtype=b_q_weight.dtype)
+    for e in range(num_experts):
+        output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k,
+                                                   size_n, num_bits)
+    return output
+
+
 def gptq_marlin_gemm(a: torch.Tensor,
                      b_q_weight: torch.Tensor,
                      b_scales: torch.Tensor,
@@ -822,9 +836,9 @@ def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
                              b_zero_points: torch.Tensor, g_idx: torch.Tensor,
                              perm: torch.Tensor, workspace: torch.Tensor,
                              b_q_type: ScalarType, size_m: int, size_n: int,
-                             size_k: int, is_k_full: bool,
-                             has_zero_point: bool, num_experts: int, topk: int,
-                             moe_block_size: int, replicate_input: bool,
+                             size_k: int, is_k_full: bool, num_experts: int,
+                             topk: int, moe_block_size: int,
+                             replicate_input: bool,
                              apply_weights: bool) -> torch.Tensor:
     return torch.empty((size_m, topk, size_n),
                        dtype=a.dtype,
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index e57b15936aa8b..5964d5a5465fd 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -12,7 +12,8 @@


 def get_scalar_type(num_bits: int, has_zp: bool):
     if has_zp:
-        return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
+        assert num_bits == 4
+        return scalar_types.uint4
     else:
         return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
@@ -24,7 +25,6 @@ def single_marlin_moe(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
-    has_zero_point: bool = False,
     g_idx: Optional[torch.Tensor] = None,
     sort_indices: Optional[torch.Tensor] = None,
     w_zeros: Optional[torch.Tensor] = None,
@@ -93,11 +93,9 @@ def single_marlin_moe(
                             device=hidden_states.device,
                             requires_grad=False)

-    if has_zero_point:
-        assert w_zeros is not None and w_zeros.nelement() > 0
-
+    has_zero_point = w_zeros is not None
     if w_zeros is None:
-        w_zeros = torch.empty((0),
+        w_zeros = torch.empty((0, 0),
                               dtype=hidden_states.dtype,
                               device=hidden_states.device,
                               requires_grad=False)
@@ -119,7 +117,7 @@ def single_marlin_moe(
     intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
         hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
         w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K,
-        is_k_full, has_zero_point, E, topk, block_size_m, True, False)
+        is_k_full, E, topk, block_size_m, True, False)

     return torch.sum(intermediate_cache.view(*intermediate_cache.shape),
                      dim=1)
@@ -133,7 +131,6 @@ def fused_marlin_moe(
     gating_output: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
-    has_zero_point: bool = False,
     g_idx1: Optional[torch.Tensor] = None,
     g_idx2: Optional[torch.Tensor] = None,
     sort_indices1: Optional[torch.Tensor] = None,
@@ -187,6 +184,20 @@ def fused_marlin_moe(
     assert hidden_states.dtype == torch.float16
     assert num_bits in [4, 8]

+    has_no_act_order = (g_idx1 is None and g_idx2 is None
+                        and sort_indices1 is None and sort_indices2 is None)
+    has_all_act_order = (g_idx1 is not None and g_idx2 is not None
+                         and sort_indices1 is not None
+                         and sort_indices2 is not None)
+    assert has_no_act_order or has_all_act_order, (
+        "g_idx and sorted_indices "
+        "must be all not None or must be all None")
+
+    has_no_zp = w1_zeros is None and w2_zeros is None
+    has_all_zp = w1_zeros is not None and w2_zeros is not None
+    assert has_no_zp or has_all_zp, ("zero points must be both not None or "
+                                     "must be both None")
+
     M, K = hidden_states.shape
     E = w1.shape[0]
     N = w2.shape[1] * 16
@@ -207,53 +218,42 @@ def fused_marlin_moe(
     sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

-    max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
+    max_workspace_size = (max(2 * N, K) // 64) * 16
     workspace = torch.zeros(max_workspace_size,
                             dtype=torch.int,
                             device="cuda",
                             requires_grad=False)

-    if has_zero_point:
-        assert w1_zeros is not None and w1_zeros.nelement() > 0
-        assert w2_zeros is not None and w2_zeros.nelement() > 0
-
-    if w1_zeros is None:
-        w1_zeros = torch.empty((0),
+    if has_no_zp:
+        w1_zeros = torch.empty((0, 0),
                                dtype=hidden_states.dtype,
                                device=hidden_states.device,
                                requires_grad=False)
-    if w2_zeros is None:
-        w2_zeros = torch.empty((0),
+        w2_zeros = torch.empty((0, 0),
                                dtype=hidden_states.dtype,
                                device=hidden_states.device,
                                requires_grad=False)

-    if g_idx1 is None:
+    if has_no_act_order:
         g_idx1 = torch.empty((0, 0),
                              dtype=torch.int32,
                              device=hidden_states.device,
                              requires_grad=False)
-
-    if g_idx2 is None:
         g_idx2 = torch.empty((0, 0),
                              dtype=torch.int32,
                              device=hidden_states.device,
                              requires_grad=False)
-
-    if sort_indices1 is None:
         sort_indices1 = torch.empty((0),
                                     dtype=torch.int32,
                                     device=hidden_states.device,
                                     requires_grad=False)
-
-    if sort_indices2 is None:
         sort_indices2 = torch.empty((0, 0),
                                     dtype=torch.int32,
                                     device=hidden_states.device,
                                     requires_grad=False)

-    scalar_type1 = get_scalar_type(num_bits, has_zero_point)
-    scalar_type2 = get_scalar_type(num_bits, has_zero_point)
+    scalar_type1 = get_scalar_type(num_bits, has_all_zp)
+    scalar_type2 = get_scalar_type(num_bits, has_all_zp)

     intermediate_cache2 = torch.empty(
         (M * topk_ids.shape[1], N),
@@ -277,7 +277,6 @@ def fused_marlin_moe(
         2 * N,
         K,
         is_k_full,
-        has_zero_point,
         E,
         topk,
         block_size_m,
@@ -303,7 +302,6 @@ def fused_marlin_moe(
         K,
         N,
         is_k_full,
-        has_zero_point,
         E,
         topk,
         block_size_m,
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index fe33b7341fd38..294fe11815c0f 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -1,16 +1,21 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional

 import torch
+from torch.nn import Parameter

 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
-    marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
+    marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
+    marlin_permute_scales, moe_awq_to_marlin_zero_points,
     verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@@ -35,12 +40,13 @@ def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
         self.group_size = group_size
         self.has_zp = has_zp
         self.lm_head_quantized = lm_head_quantized
+        self.weight_bits = weight_bits

-        if weight_bits not in self.TYPE_MAP:
-            raise ValueError(f"Unsupported num_bits = {weight_bits}. "
+        if self.weight_bits not in self.TYPE_MAP:
+            raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
                              f"Supported num_bits = {self.TYPE_MAP.keys()}")

-        self.quant_type = self.TYPE_MAP[weight_bits]
+        self.quant_type = self.TYPE_MAP[self.weight_bits]

         verify_marlin_supported(self.quant_type,
                                 group_size=self.group_size,
@@ -98,10 +104,12 @@ def override_quantization_method(cls, hf_quant_cfg,
         return None

     def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["AWQMarlinLinearMethod"]:
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
         if (isinstance(layer, LinearBase) or
             (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
             return AWQMarlinLinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return AWQMoEMethod(self)
         return None

     def get_scaled_act_names(self) -> List[str]:
@@ -271,4 +279,182 @@ def apply(
             quant_type=self.quant_config.quant_type,
             output_size_per_partition=layer.output_size_per_partition,
             input_size_per_partition=layer.input_size_per_partition,
-            bias=bias)
\ No newline at end of file
+            bias=bias)
+
+
+class AWQMoEMethod(FusedMoEMethodBase):
+
+    def __init__(self, quant_config: AWQMarlinConfig):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+        extra_weight_attrs.update({
+            "is_transposed":
+            True,
+            "quant_method":
+            FusedMoeWeightScaleSupported.GROUP.value,
+        })
+
+        w13_qweight = Parameter(torch.empty(num_experts,
+                                            hidden_size,
+                                            2 * intermediate_size //
+                                            self.quant_config.pack_factor,
+                                            dtype=torch.int32),
+                                requires_grad=False)
+        layer.register_parameter("w13_qweight", w13_qweight)
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+
+        w2_qweight = Parameter(torch.empty(num_experts,
+                                           intermediate_size,
+                                           hidden_size //
+                                           self.quant_config.pack_factor,
+                                           dtype=torch.int32),
+                               requires_grad=False)
+        layer.register_parameter("w2_qweight", w2_qweight)
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+
+        num_groups_w13 = hidden_size // self.quant_config.group_size
+        num_groups_w2 = intermediate_size // self.quant_config.group_size
+
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        w13_scales = Parameter(torch.empty(num_experts,
+                                           num_groups_w13,
+                                           intermediate_size * 2,
+                                           dtype=params_dtype),
+                               requires_grad=False)
+        layer.register_parameter("w13_scales", w13_scales)
+        set_weight_attrs(w13_scales, extra_weight_attrs)
+
+        w2_scales = Parameter(torch.empty(num_experts,
+                                          num_groups_w2,
+                                          hidden_size,
+                                          dtype=params_dtype),
+                              requires_grad=False)
+        layer.register_parameter("w2_scales", w2_scales)
+        set_weight_attrs(w2_scales, extra_weight_attrs)
+
+        # WEIGHT_ZERO_POINT
+        # Allocate 2 zero points for w1 and w3 respectively.
+        w13_qzeros = Parameter(torch.empty(num_experts,
+                                           num_groups_w13,
+                                           2 * intermediate_size //
+                                           self.quant_config.pack_factor,
+                                           dtype=torch.int32),
+                               requires_grad=False)
+        layer.register_parameter("w13_qzeros", w13_qzeros)
+        set_weight_attrs(w13_qzeros, extra_weight_attrs)
+
+        w2_qzeros = Parameter(torch.empty(num_experts,
+                                          num_groups_w2,
+                                          hidden_size //
+                                          self.quant_config.pack_factor,
+                                          dtype=torch.int32),
+                              requires_grad=False)
+        layer.register_parameter("w2_qzeros", w2_qzeros)
+        set_weight_attrs(w2_qzeros, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        num_experts = layer.w13_qweight.shape[0]
+        device = layer.w13_qweight.device
+
+        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+            requires_grad=False,
+        )
+        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+            requires_grad=False,
+        )
+
+        marlin_w13_qweight = ops.awq_marlin_moe_repack(
+            layer.w13_qweight,
+            layer.w13_g_idx_sort_indices,
+            size_k=layer.w13_qweight.shape[1],
+            size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
+            num_bits=self.quant_config.weight_bits,
+        )
+        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
+
+        marlin_w2_qweight = ops.awq_marlin_moe_repack(
+            layer.w2_qweight,
+            layer.w2_g_idx_sort_indices,
+            size_k=layer.w2_qweight.shape[1],
+            size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
+            num_bits=self.quant_config.weight_bits,
+        )
+        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
+
+        # Why does this take the intermediate size for size_k?
+        marlin_w13_scales = marlin_moe_permute_scales(
+            s=layer.w13_scales,
+            size_k=layer.intermediate_size_per_partition,
+            size_n=layer.w13_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+
+        replace_parameter(layer, "w13_scales", marlin_w13_scales)
+
+        marlin_w2_scales = marlin_moe_permute_scales(
+            s=layer.w2_scales,
+            size_k=layer.intermediate_size_per_partition,
+            size_n=layer.w2_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+        replace_parameter(layer, "w2_scales", marlin_w2_scales)
+
+        marlin_w13_zp = moe_awq_to_marlin_zero_points(
+            layer.w13_qzeros,
+            size_k=layer.w13_qzeros.shape[1],
+            size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
+            num_bits=self.quant_config.weight_bits)
+        replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
+
+        marlin_w2_zp = moe_awq_to_marlin_zero_points(
+            layer.w2_qzeros,
+            size_k=layer.w2_qzeros.shape[1],
+            size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
+            num_bits=self.quant_config.weight_bits)
+        replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ) -> torch.Tensor:
+
+        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+            fused_marlin_moe)
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function)
+
+        return fused_marlin_moe(
+            x,
+            layer.w13_qweight,
+            layer.w2_qweight,
+            layer.w13_scales,
+            layer.w2_scales,
+            router_logits,
+            topk_weights,
+            topk_ids,
+            w1_zeros=layer.w13_qzeros,
+            w2_zeros=layer.w2_qzeros,
+            num_bits=self.quant_config.weight_bits,
+        )
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
index 53762965732ce..9a1defa409714 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -208,6 +208,7 @@ def marlin_moe_permute_scales(
         device=s.device,
         dtype=s.dtype,
     )
+
     for e in range(num_experts):
         output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
     return output
@@ -258,6 +259,20 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
     return marlin_zp


+def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
+                                  size_n: int, num_bits: int):
+    num_experts = q_zp_packed.shape[0]
+    output = torch.empty(
+        (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
+        device=q_zp_packed.device,
+        dtype=q_zp_packed.dtype,
+    )
+    for e in range(num_experts):
+        output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n,
+                                              num_bits)
+    return output
+
+
 def apply_gptq_marlin_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index 2bfe6ea09bd62..b95c0b7cd0612 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -23,7 +23,9 @@ def get_model_architecture(
     architectures = getattr(model_config.hf_config, "architectures", [])
     # Special handling for quantized Mixtral.
     # FIXME(woosuk): This is a temporary hack.
-    mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"]
+    mixtral_supported = [
+        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
+    ]

     if (model_config.quantization is not None
             and model_config.quantization not in mixtral_supported