diff --git a/csrc/ops.h b/csrc/ops.h
index 023455f8a1530..6094599901022 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -63,12 +63,12 @@ void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,
                         const torch::Tensor& scales,
-                        const torch::Tensor& codebook_partition_sizes,
+                        const std::vector<int64_t>& codebook_partition_sizes,
                         const std::optional<torch::Tensor>& bias);
 
-torch::Tensor aqlm_dequant(const torch::Tensor& codes,
-                           const torch::Tensor& codebooks,
-                           const torch::Tensor& codebook_partition_sizes);
+torch::Tensor aqlm_dequant(
+    const torch::Tensor& codes, const torch::Tensor& codebooks,
+    const std::vector<int64_t>& codebook_partition_sizes);
 
 torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                        torch::Tensor _scaling_factors, torch::Tensor _zeros,
@@ -107,13 +107,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
 torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                 int64_t size_n, int64_t num_bits);
 
-torch::Tensor ggml_dequantize(torch::Tensor W, int8_t type, int64_t m,
+torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
                               int64_t n);
 
-torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int8_t type,
-                                  int64_t row);
+torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
+                                  int64_t type, int64_t row);
 
-torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int8_t type,
+torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
                               int64_t row);
 
 torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu
index 22da5e4f08a18..79cd2c610b3c2 100644
--- a/csrc/quantization/aqlm/gemm_kernels.cu
+++ b/csrc/quantization/aqlm/gemm_kernels.cu
@@ -496,14 +496,14 @@ torch::Tensor code2x8_matmat(const torch::Tensor& input,
 }
 
 // Accumulate the partition sizes.
-int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
+int4 accumulate_sizes(const std::vector<int64_t>& codebook_partition_sizes) {
   int4 cumulative_sizes;
   auto cumulative_size = &cumulative_sizes.x;
-  int i = 0;
+  size_t i = 0;
   int last = 0;
-  assert(codebook_partition_sizes.size(0) <= 4);
-  for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) {
-    *cumulative_size = codebook_partition_sizes[i].item<int>() + last;
+  assert(codebook_partition_sizes.size() <= 4);
+  for (; i < codebook_partition_sizes.size(); ++i, ++cumulative_size) {
+    *cumulative_size = codebook_partition_sizes[i] + last;
     last = *cumulative_size;
   }
   // fill in the rest with unreachable.
@@ -519,12 +519,12 @@ int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,
                         const torch::Tensor& scales,
-                        const torch::Tensor& codebook_partition_sizes,
+                        const std::vector<int64_t>& codebook_partition_sizes,
                         const std::optional<torch::Tensor>& bias) {
   int4 cumulative_sizes =
       vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
 
-  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
+  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size();
   int const entries = codebooks.size(1);
 
   if (nbooks == 1 && entries == (1 << 16)) {
@@ -541,13 +541,13 @@ torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
   return {};
 }
 
-torch::Tensor aqlm_dequant(const torch::Tensor& codes,
-                           const torch::Tensor& codebooks,
-                           const torch::Tensor& codebook_partition_sizes) {
+torch::Tensor aqlm_dequant(
+    const torch::Tensor& codes, const torch::Tensor& codebooks,
+    const std::vector<int64_t>& codebook_partition_sizes) {
   int4 cumulative_sizes =
       vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
 
-  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
+  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size();
   int const entries = codebooks.size(1);
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(codes));
@@ -557,7 +557,8 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes,
   auto in_features = codes.size(1) * 8;
   auto out_features = codes.size(0);
 
-  assert(out_features = codebook_partition_sizes.sum().item<int>());
+  assert(out_features == std::accumulate(codebook_partition_sizes.begin(),
+                                         codebook_partition_sizes.end(), 0));
 
   auto weights = torch::empty({out_features, in_features},
                               torch::TensorOptions()
diff --git a/csrc/quantization/gguf/dequantize.cuh b/csrc/quantization/gguf/dequantize.cuh
index 03c080f645f02..2069fba759ea0 100644
--- a/csrc/quantization/gguf/dequantize.cuh
+++ b/csrc/quantization/gguf/dequantize.cuh
@@ -487,7 +487,7 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
     dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
-static to_fp16_cuda_t ggml_get_to_fp16_cuda(int type) {
+static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
     switch (type) {
         case 2:
             return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu
index 9beae1bec4034..966d9992b25fd 100644
--- a/csrc/quantization/gguf/gguf_kernel.cu
+++ b/csrc/quantization/gguf/gguf_kernel.cu
@@ -60,7 +60,7 @@ static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx,
 }
 
 torch::Tensor ggml_dequantize(torch::Tensor W,  // quant weight
-                              int8_t type, int64_t m, int64_t n) {
+                              int64_t type, int64_t m, int64_t n) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
   auto options =
       torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
@@ -73,7 +73,7 @@ torch::Tensor ggml_dequantize(torch::Tensor W,  // quant weight
 
 torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W,  // quant weight
                                   torch::Tensor X,  // input
-                                  int8_t type, int64_t row) {
+                                  int64_t type, int64_t row) {
   int col = X.sizes()[1];
   const int padded = (col + 512 - 1) / 512 * 512;
   const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
@@ -172,7 +172,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W,  // quant weight
 
 torch::Tensor ggml_mul_mat_a8(torch::Tensor W,  // quant weight
                               torch::Tensor X,  // input
-                              int8_t type, int64_t row) {
+                              int64_t type, int64_t row) {
   int col = X.sizes()[1];
   int padded = (col + 512 - 1) / 512 * 512;
   int batch = X.sizes()[0];
@@ -239,4 +239,4 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W,  // quant weight
       break;
   }
   return Y;
-}
\ No newline at end of file
+}
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 59fe5329861cc..1f0a111a53bcc 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -17,13 +17,7 @@
     logger.warning("Failed to import from vllm._C with %r", e)
 
 with contextlib.suppress(ImportError):
-    # ruff: noqa: F401
-    import vllm._moe_C
-
-
-def is_custom_op_supported(op_name: str) -> bool:
-    op, overloads = torch._C._jit_get_operation(op_name)
-    return op is not None
+    import vllm._moe_C  # noqa: F401
 
 
 def hint_on_error(fn):
@@ -280,14 +274,14 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
 
 # aqlm
 def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
               codebooks: torch.Tensor, scales: torch.Tensor,
-              codebook_partition_sizes: torch.Tensor,
+              codebook_partition_sizes: List[int],
               bias: Optional[torch.Tensor]) -> torch.Tensor:
     return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
                                   codebook_partition_sizes, bias)
 
 
 def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
-                 codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
+                 codebook_partition_sizes: List[int]) -> torch.Tensor:
     return torch.ops._C.aqlm_dequant(codes, codebooks,
                                      codebook_partition_sizes)
@@ -434,25 +428,17 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 
 # gguf
-def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int):
+def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
+                    n: int) -> torch.Tensor:
     return torch.ops._C.ggml_dequantize(W, quant_type, m, n)
 
 
-def ggml_mul_mat_vec(
-    W: torch.Tensor,
-    X: torch.Tensor,
-    quant_type: int,
-    row: int,
-):
-    return torch.ops._C.ggml_mul_mat_vec(W, X, quant_type, row)
-
-
 def ggml_mul_mat_vec_a8(
     W: torch.Tensor,
     X: torch.Tensor,
     quant_type: int,
     row: int,
-):
+) -> torch.Tensor:
     return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row)
 
 
@@ -461,7 +447,7 @@ def ggml_mul_mat_a8(
     X: torch.Tensor,
     quant_type: int,
     row: int,
-):
+) -> torch.Tensor:
     return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)
 
 
diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py
index 95ff05b986ab4..c88ca340ebcc5 100644
--- a/vllm/model_executor/layers/quantization/aqlm.py
+++ b/vllm/model_executor/layers/quantization/aqlm.py
@@ -95,7 +95,7 @@ def generic_dequantize_gemm(
     codebooks: torch.
     Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
     scales: torch.Tensor,  #  [num_out_groups, 1, 1, 1]
-    output_partition_sizes: torch.IntTensor,
+    output_partition_sizes: List[int],
     bias: Optional[torch.Tensor],
 ) -> torch.Tensor:
     output_shape = input.shape[:-1] + (scales.shape[0], )
@@ -133,7 +133,7 @@ def optimized_dequantize_gemm(
     codebooks: torch.
     Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
     scales: torch.Tensor,  #  [num_out_groups, 1, 1, 1]
-    output_partition_sizes: torch.IntTensor,
+    output_partition_sizes: List[int],
     bias: Optional[torch.Tensor],
 ) -> torch.Tensor:
     weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
@@ -288,10 +288,8 @@ def create_weights(self, layer: torch.nn.Module,
             codebooks,
             {
                 # metadata indicates fixed size concatenated along dim 0
-                "is_metadata":
-                True,
-                "output_partition_sizes":
-                torch.tensor(output_partition_sizes, device='cpu'),
+                "is_metadata": True,
+                "output_partition_sizes": output_partition_sizes
             },
         )
 
@@ -334,7 +332,7 @@ def apply(
         codes = layer.codes
         scales = layer.scales
         output_partition_sizes = getattr(codebooks, "output_partition_sizes",
-                                         None)
+                                         [])
 
         nbooks = codes.shape[2]
         ingroups = codebooks.shape[3]
diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index aa04fcf8310bf..f456286899a53 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -212,6 +212,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
         else:
             layer.g_idx.data = torch.empty((0, ),
+                                           dtype=torch.int,
                                            device=layer.g_idx.device)
         layer.exllama_state = ExllamaState.READY
         ops.gptq_shuffle(layer.qweight, layer.g_idx,