From 71cedee57c9b4f79c3ae3d28b78e9a6cf2e6fe0e Mon Sep 17 00:00:00 2001 From: miloice Date: Fri, 30 Aug 2024 06:11:34 +0000 Subject: [PATCH 1/6] initial port --- CMakeLists.txt | 2 +- csrc/ops.h | 2 ++ csrc/quantization/gguf/ggml-common.h | 15 +++++++++ csrc/quantization/gguf/gguf_kernel.cu | 6 ++-- csrc/quantization/gguf/mmvq.cuh | 4 +++ csrc/quantization/gguf/vecdotq.cuh | 44 +++++++++++++-------------- csrc/torch_bindings.cpp | 2 ++ vllm/config.py | 2 +- 8 files changed, 51 insertions(+), 26 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 923ed084ffd9e..ac9ca36678bcf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -185,6 +185,7 @@ set(VLLM_EXT_SRC "csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/fp8/common.cu" + "csrc/quantization/gguf/gguf_kernel.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" "csrc/prepare_inputs/advance_step.cu" @@ -213,7 +214,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/gptq_marlin/awq_marlin_repack.cu" - "csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/fp8/fp8_marlin.cu" "csrc/custom_all_reduce.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" diff --git a/csrc/ops.h b/csrc/ops.h index 8d24545de898d..01e890aaff5ef 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -125,6 +125,7 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits); +#endif torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, int64_t n); @@ -135,6 +136,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row); +#ifndef USE_ROCM torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, diff --git a/csrc/quantization/gguf/ggml-common.h b/csrc/quantization/gguf/ggml-common.h index d7989d84bf68e..d78e7dc93c662 100644 --- a/csrc/quantization/gguf/ggml-common.h +++ b/csrc/quantization/gguf/ggml-common.h @@ -966,4 +966,19 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { #endif return c; } + +static __device__ __forceinline__ uint32_t __vcmpeq4(const uint32_t a, const uint32_t b) { + uint32_t neq = a^b; + return !(neq & 0xff000000) * 0xff000000 | + !(neq & 0x00ff0000) * 0x00ff0000 | + !(neq & 0x0000ff00) * 0x0000ff00 | + !(neq & 0x000000ff) * 0x000000ff; +} + +static __device__ __forceinline__ uint32_t __vsub4(const uint32_t a, const uint32_t b) { + return (static_cast(((a & 0xff000000) >> 24) - ((b & 0xff000000) >> 24)) << 24) + + (static_cast(((a & 0x00ff0000) >> 16) - ((b & 0x00ff0000) >> 16)) << 16) + + (static_cast(((a & 0x0000ff00) >> 8) - ((b & 0x0000ff00) >> 8)) << 8) + + (static_cast(((a & 0x000000ff) >> 0) - ((b & 0x000000ff) >> 0)) << 0); +} #endif // defined(USE_ROCM) diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu index 966d9992b25fd..c6eea763e457d 100644 --- a/csrc/quantization/gguf/gguf_kernel.cu +++ b/csrc/quantization/gguf/gguf_kernel.cu @@ -10,6 +10,8 @@ #include "mmvq.cuh" #include "mmq.cuh" +#include "cuda_compat.h" + // Q8 gemv static __global__ void quantize_q8_1(const half* __restrict__ x, void* 
__restrict__ vy, const int kx, @@ -32,8 +34,8 @@ static __global__ void quantize_q8_1(const half* __restrict__ x, #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { - amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32)); - sum += __shfl_xor_sync(0xffffffff, sum, mask, 32); + amax = fmaxf(amax, VLLM_SHFL_XOR_SYNC_WIDTH(amax, mask, 32)); + sum += VLLM_SHFL_XOR_SYNC_WIDTH(sum, mask, 32); } const float d = amax / 127; diff --git a/csrc/quantization/gguf/mmvq.cuh b/csrc/quantization/gguf/mmvq.cuh index ef2ea072392d2..8a9a5f3ba29a2 100644 --- a/csrc/quantization/gguf/mmvq.cuh +++ b/csrc/quantization/gguf/mmvq.cuh @@ -29,7 +29,11 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * // sum up partial sums and write back result #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { +#ifndef USE_ROCM tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); +#else + tmp += __shfl_xor(tmp, mask, 32); +#endif } if (threadIdx.x == 0) { diff --git a/csrc/quantization/gguf/vecdotq.cuh b/csrc/quantization/gguf/vecdotq.cuh index 78c749d3f3bc1..70ccb207cca40 100644 --- a/csrc/quantization/gguf/vecdotq.cuh +++ b/csrc/quantization/gguf/vecdotq.cuh @@ -30,7 +30,7 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl( const int * v, const int * u, const float & d4, const half2 & ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; #pragma unroll @@ -55,7 +55,7 @@ template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl( const int * v, const int * u, const half2 & dm4, const half2 & ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; #pragma unroll @@ -82,7 +82,7 @@ template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl( const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; #pragma unroll @@ -115,7 +115,7 @@ template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl( const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; #pragma unroll @@ -149,7 +149,7 @@ template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl( const int * v, const int * u, const float & d8_0, const float & d8_1) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; #pragma unroll @@ -163,7 +163,7 @@ template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl( const int * v, const int * u, const half2 & dm8, const half2 & ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; @@ -189,7 +189,7 @@ template 
static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales, const half2 & dm2, const float * __restrict__ d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -217,7 +217,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales, const half2 & dm2, const float & d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi_d = 0; int sumi_m = 0; @@ -254,7 +254,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales, const int & scale_offset, const float & d3, const float * __restrict__ d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf = 0.0f; @@ -288,7 +288,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales, const float & d3, const float & d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; #pragma unroll @@ -313,7 +313,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -338,7 +338,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -369,7 +369,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -400,7 +400,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * 
__restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -432,7 +432,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales, const float & d, const float * __restrict__ d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf = 0.0f; #pragma unroll @@ -452,7 +452,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc, const float & d6, const float * __restrict__ d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf_d = 0.0f; #pragma unroll @@ -1569,7 +1569,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1( static __device__ __forceinline__ float vec_dot_iq2_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq2_s * bq2 = (const block_iq2_s *) vbq; const int ib32 = iqs; @@ -1606,7 +1606,7 @@ static __device__ __forceinline__ float vec_dot_iq2_s_q8_1( static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq; const int ib32 = iqs; @@ -1633,7 +1633,7 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq3_s * bq2 = (const block_iq3_s *) vbq; const int ib32 = iqs; @@ -1658,7 +1658,7 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq1_s * bq1 = (const block_iq1_s *) vbq; const int ib32 = iqs; @@ -1698,7 +1698,7 @@ static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4 static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq4_nl * bq = (const block_iq4_nl *) vbq; @@ -1723,7 +1723,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1( static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined 
__CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq; const uint8_t * values = (const uint8_t *)kvalues_iq4nl; diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7783acd741f5f..43b27bed5ee44 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -159,6 +159,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // awq_marlin repack from AWQ. ops.def("awq_marlin_repack", &awq_marlin_repack); ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack); +#endif // Dequantization for GGML. ops.def("ggml_dequantize", &ggml_dequantize); @@ -172,6 +173,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8); ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8); +#ifndef USE_ROCM // fp8_marlin Optimized Quantized GEMM for FP8 weight-only. ops.def("fp8_marlin_gemm", &fp8_marlin_gemm); ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm); diff --git a/vllm/config.py b/vllm/config.py index 0a34dabf57e7c..ffca879e8a29a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -267,7 +267,7 @@ def _parse_quant_hf_config(self): def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] - rocm_supported_quantization = ["awq", "gptq", "squeezellm", "fp8"] + rocm_supported_quantization = ["awq", "gptq", "squeezellm", "fp8", "gguf"] optimized_quantization_methods = [ "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fbgemm_fp8", "compressed_tensors", "compressed-tensors", From 0ca2077f442ff0af6cbc831ed054453cb3f23aa1 Mon Sep 17 00:00:00 2001 From: miloice Date: Sat, 31 Aug 2024 08:20:06 +0000 Subject: [PATCH 2/6] format --- csrc/quantization/gguf/ggml-common.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/gguf/ggml-common.h b/csrc/quantization/gguf/ggml-common.h index d78e7dc93c662..224e99ce14caf 100644 --- a/csrc/quantization/gguf/ggml-common.h +++ b/csrc/quantization/gguf/ggml-common.h @@ -976,9 +976,9 @@ static __device__ __forceinline__ uint32_t __vcmpeq4(const uint32_t a, const uin } static __device__ __forceinline__ uint32_t __vsub4(const uint32_t a, const uint32_t b) { - return (static_cast(((a & 0xff000000) >> 24) - ((b & 0xff000000) >> 24)) << 24) + - (static_cast(((a & 0x00ff0000) >> 16) - ((b & 0x00ff0000) >> 16)) << 16) + - (static_cast(((a & 0x0000ff00) >> 8) - ((b & 0x0000ff00) >> 8)) << 8) + - (static_cast(((a & 0x000000ff) >> 0) - ((b & 0x000000ff) >> 0)) << 0); + return (static_cast(((a & 0xff000000) >> 24) - ((b & 0xff000000) >> 24)) << 24) + + (static_cast(((a & 0x00ff0000) >> 16) - ((b & 0x00ff0000) >> 16)) << 16) + + (static_cast(((a & 0x0000ff00) >> 8) - ((b & 0x0000ff00) >> 8)) << 8) + + (static_cast(((a & 0x000000ff) >> 0) - ((b & 0x000000ff) >> 0)) << 0); } #endif // defined(USE_ROCM) From 02dbde6349a8d1d518af42018accb44d335cbb03 Mon Sep 17 00:00:00 2001 From: miloice Date: Tue, 5 Nov 2024 11:01:12 +0000 Subject: [PATCH 3/6] fix warp size 32 --- csrc/ops.h | 2 + csrc/quantization/gguf/ggml-common.h | 2 +- csrc/quantization/gguf/mmq.cuh | 70 ++++---- csrc/quantization/gguf/mmvq.cuh | 40 ++--- csrc/quantization/gguf/vecdotq.cuh | 242 +++++++++++++-------------- csrc/torch_bindings.cpp | 1 + vllm/_custom_ops.py | 51 +++--- 7 files changed, 206 insertions(+), 202 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index c50eb39a3dacc..604146deb17a5 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -118,6 
+118,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, int64_t thx, int64_t thy); torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm); +#endif torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, int64_t n); @@ -128,6 +129,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row); +#ifndef USE_ROCM bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, diff --git a/csrc/quantization/gguf/ggml-common.h b/csrc/quantization/gguf/ggml-common.h index 21eb046abd48d..d42205a6571db 100644 --- a/csrc/quantization/gguf/ggml-common.h +++ b/csrc/quantization/gguf/ggml-common.h @@ -1,7 +1,7 @@ // copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h #define QK_K 256 #define K_QUANTS_PER_ITERATION 2 -#define WARP_SIZE 32 +#define WARP_SIZE_GGUF 32 #define K_SCALE_SIZE 12 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 #define CUDA_QUANTIZE_BLOCK_SIZE 256 diff --git a/csrc/quantization/gguf/mmq.cuh b/csrc/quantization/gguf/mmq.cuh index d13efd5965313..c935faa07df0c 100644 --- a/csrc/quantization/gguf/mmq.cuh +++ b/csrc/quantization/gguf/mmq.cuh @@ -10,7 +10,7 @@ static __device__ __forceinline__ void mul_mat_q( const int blocks_per_row_x = ncols_x / qk; const int blocks_per_col_y = nrows_y / QK8_1; - const int blocks_per_warp = WARP_SIZE / qi; + const int blocks_per_warp = WARP_SIZE_GGUF / qi; const int & ncols_dst = ncols_y; @@ -27,10 +27,10 @@ static __device__ __forceinline__ void mul_mat_q( allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc); - __shared__ int tile_y_qs[mmq_x * WARP_SIZE]; - __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1]; + __shared__ int tile_y_qs[mmq_x * WARP_SIZE_GGUF]; + __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE_GGUF/QI8_1]; - float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}}; + float sum[mmq_y/WARP_SIZE_GGUF][mmq_x/nwarps] = {{0.0f}}; for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) { @@ -39,26 +39,26 @@ static __device__ __forceinline__ void mul_mat_q( #pragma unroll for (int ir = 0; ir < qr; ++ir) { - const int kqs = ir*WARP_SIZE + threadIdx.x; + const int kqs = ir*WARP_SIZE_GGUF + threadIdx.x; const int kbxd = kqs / QI8_1; #pragma unroll for (int i = 0; i < mmq_x; i += nwarps) { const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd]; - const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE; + const int index_y = (threadIdx.y + i) * WARP_SIZE_GGUF + kqs % WARP_SIZE_GGUF; tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1); } #pragma unroll for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) { - const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x; - const int kby = threadIdx.x % (WARP_SIZE/QI8_1); + const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE_GGUF/QI8_1)) % mmq_x; + const int kby = threadIdx.x % (WARP_SIZE_GGUF/QI8_1); const int col_y_eff = min(col_y_0 + ids, ncols_y-1); // if the sum is not needed it's faster to transform the scale to f32 ahead of time - const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds; - half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby]; + const half2 * dsi_src = 
&y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE_GGUF/QI8_1) + kby].ds; + half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE_GGUF/QI8_1) + kby]; if (need_sum) { *dsi_dst = *dsi_src; } else { @@ -70,12 +70,12 @@ static __device__ __forceinline__ void mul_mat_q( __syncthreads(); // #pragma unroll // unrolling this loop causes too much register pressure - for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) { + for (int k = ir*WARP_SIZE_GGUF/qr; k < (ir+1)*WARP_SIZE_GGUF/qr; k += vdr) { #pragma unroll for (int j = 0; j < mmq_x; j += nwarps) { #pragma unroll - for (int i = 0; i < mmq_y; i += WARP_SIZE) { - sum[i/WARP_SIZE][j/nwarps] += vec_dot( + for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) { + sum[i/WARP_SIZE_GGUF][j/nwarps] += vec_dot( tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, threadIdx.x + i, threadIdx.y + j, k); } @@ -93,12 +93,12 @@ static __device__ __forceinline__ void mul_mat_q( } #pragma unroll - for (int i = 0; i < mmq_y; i += WARP_SIZE) { + for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) { const int row_dst = row_dst_0 + threadIdx.x + i; if (row_dst >= nrows_dst) { continue; } - dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE][j/nwarps]); + dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE_GGUF][j/nwarps]); } } } @@ -115,7 +115,7 @@ static __device__ __forceinline__ void mul_mat_q( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q4_0, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_0, 2) #endif mul_mat_q4_0( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -140,7 +140,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -165,7 +165,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q4_1, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_1, 2) #endif mul_mat_q4_1( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -190,7 +190,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -215,7 +215,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q5_0, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_0, 2) #endif mul_mat_q5_0( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -240,7 +240,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -265,7 +265,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) 
-__launch_bounds__(WARP_SIZE*NWARPS_Q5_1, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_1, 2) #endif mul_mat_q5_1( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -289,7 +289,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -314,7 +314,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q8_0, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q8_0, 2) #endif mul_mat_q8_0( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -338,7 +338,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -363,7 +363,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q2_K, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q2_K, 2) #endif mul_mat_q2_K( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -387,7 +387,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -412,7 +412,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q3_K, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q3_K, 2) #endif mul_mat_q3_K( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -438,7 +438,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -463,7 +463,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q4_K, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_K, 2) #endif mul_mat_q4_K( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -487,7 +487,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -512,7 +512,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q5_K, 2) 
+__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_K, 2) #endif mul_mat_q5_K( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -537,7 +537,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -562,7 +562,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q6_K, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q6_K, 2) #endif mul_mat_q6_K( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -586,7 +586,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; diff --git a/csrc/quantization/gguf/mmvq.cuh b/csrc/quantization/gguf/mmvq.cuh index 77e17a8128e58..82dd69c335fcd 100644 --- a/csrc/quantization/gguf/mmvq.cuh +++ b/csrc/quantization/gguf/mmvq.cuh @@ -8,7 +8,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; + const int blocks_per_warp = vdr * WARP_SIZE_GGUF / qi; // partial sum for each thread float tmp = 0.0f; @@ -44,7 +44,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -52,7 +52,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -60,7 +60,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -68,7 +68,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = 
(nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -76,7 +76,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -84,7 +84,7 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -92,7 +92,7 @@ static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -100,7 +100,7 @@ static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -108,7 +108,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -116,7 +116,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -124,7 +124,7 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const 
void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -132,7 +132,7 @@ static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, half static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -140,7 +140,7 @@ static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, half static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -148,7 +148,7 @@ static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -156,7 +156,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, half static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -164,7 +164,7 @@ static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -172,7 +172,7 @@ static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -180,7 +180,7 @@ static void 
mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -188,7 +188,7 @@ static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, half static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } diff --git a/csrc/quantization/gguf/vecdotq.cuh b/csrc/quantization/gguf/vecdotq.cuh index 5c82c1ca8558f..e00422637c65b 100644 --- a/csrc/quantization/gguf/vecdotq.cuh +++ b/csrc/quantization/gguf/vecdotq.cuh @@ -507,8 +507,8 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0]; + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE_GGUF) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE_GGUF/QI4_0) + mmq_y/QI4_0]; *x_ql = tile_x_qs; *x_dm = (half2 *) tile_x_d; } @@ -529,11 +529,11 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); - // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d; + x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); + // x_dmf[i * (WARP_SIZE_GGUF/QI4_0) + i / QI4_0 + kbx] = bxi->d; } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI4_0; const int kbxd = k % blocks_per_tile_x_row; #pragma unroll @@ -543,7 +543,7 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = __half2float(bxi->d); + x_dmf[i * (WARP_SIZE_GGUF/QI4_0) + i / QI4_0 + kbxd] = __half2float(bxi->d); } } @@ -559,13 +559,13 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat( #pragma unroll for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; + u[2*l+0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF]; + u[2*l+1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI4_0) % WARP_SIZE_GGUF]; } return vec_dot_q4_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i * (WARP_SIZE_GGUF + 1) + k], u, x_dmf[i * (WARP_SIZE_GGUF/QI4_0) + i/QI4_0 + k/QI4_0], + y_ds[j * (WARP_SIZE_GGUF/QI8_1) + (2*k/QI8_1) % (WARP_SIZE_GGUF/QI8_1)]); } static __device__ __forceinline__ float vec_dot_q4_1_q8_1( @@ -587,8 +587,8 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1( } 
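The recurring change in this series is mechanical: every warp-size-dependent constant is renamed from WARP_SIZE to WARP_SIZE_GGUF (still 32), and raw __shfl_xor_sync calls are routed through a width-aware wrapper (VLLM_SHFL_XOR_SYNC_WIDTH, presumably provided by the cuda_compat.h header included in patch 1, or a plain __shfl_xor on ROCm), because a ROCm wavefront is typically 64 lanes wide on the CDNA GPUs vLLM targets while these kernels reduce over 32-lane groups. The sketch below illustrates that pattern only; PORTABLE_SHFL_XOR and warp32_sum are made-up names, not identifiers from the patch.

    // Illustrative sketch, not part of the patch. PORTABLE_SHFL_XOR and
    // warp32_sum stand in for vLLM's VLLM_SHFL_XOR_SYNC_WIDTH and the
    // in-kernel reduction loops this series touches.
    #ifdef USE_ROCM
      // HIP: __shfl_xor takes an explicit width, so the butterfly stays
      // inside a 32-lane group even on a 64-lane wavefront.
      #define PORTABLE_SHFL_XOR(var, lane_mask, width) \
        __shfl_xor((var), (lane_mask), (width))
    #else
      // CUDA: the *_sync variant additionally takes a participation mask.
      #define PORTABLE_SHFL_XOR(var, lane_mask, width) \
        __shfl_xor_sync(0xffffffffu, (var), (lane_mask), (width))
    #endif

    // Butterfly sum across the 32 lanes of one GGUF "warp" (WARP_SIZE_GGUF).
    __device__ __forceinline__ float warp32_sum(float v) {
    #pragma unroll
      for (int mask = 16; mask > 0; mask >>= 1) {
        v += PORTABLE_SHFL_XOR(v, mask, 32);
      }
      return v;  // every lane of the 32-lane group now holds the same sum
    }

Patch 1 also backfills two byte-wise integer intrinsics that the GGUF dot products use but HIP does not provide: __vcmpeq4 (per-byte equality mask) and __vsub4 (per-byte subtraction). The small host-side C++ program below mirrors the same packed-byte bit manipulation so it can be sanity-checked off-device; the function names are arbitrary and this is a sketch, not the __device__ code added by the patch.

    // Illustrative host-side mirror of the ROCm fallbacks in ggml-common.h.
    #include <cstdint>
    #include <cstdio>

    // Per-byte subtraction of two packed uint8x4 values (wraps per byte).
    static uint32_t sub4(uint32_t a, uint32_t b) {
      uint32_t r = 0;
      for (int sh = 0; sh < 32; sh += 8) {
        uint8_t d = static_cast<uint8_t>(((a >> sh) & 0xff) - ((b >> sh) & 0xff));
        r |= static_cast<uint32_t>(d) << sh;
      }
      return r;
    }

    // Per-byte equality: each result byte is 0xff where a and b match, else 0x00.
    static uint32_t cmpeq4(uint32_t a, uint32_t b) {
      uint32_t neq = a ^ b, r = 0;
      for (int sh = 0; sh < 32; sh += 8) {
        if (((neq >> sh) & 0xff) == 0) r |= 0xffu << sh;
      }
      return r;
    }

    int main() {
      printf("%08x\n", sub4(0x10203040u, 0x01010101u));    // 0f1f2f3f
      printf("%08x\n", cmpeq4(0xaabbccddu, 0xaabb00ddu));  // ffff00ff
      return 0;
    }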
template static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1]; + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE_GGUF) + + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI4_1) + mmq_y/QI4_1]; *x_ql = tile_x_qs; *x_dm = tile_x_dm; } @@ -608,10 +608,10 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI4_1; const int kbxd = k % blocks_per_tile_x_row; #pragma unroll @@ -621,7 +621,7 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd; - x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; + x_dm[i * (WARP_SIZE_GGUF/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; } } @@ -634,13 +634,13 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat( #pragma unroll for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE]; + u[2*l+0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF]; + u[2*l+1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI4_1) % WARP_SIZE_GGUF]; } return vec_dot_q4_1_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i * (WARP_SIZE_GGUF + 1) + k], u, x_dm[i * (WARP_SIZE_GGUF/QI4_1) + i/QI4_1 + k/QI4_1], + y_ds[j * (WARP_SIZE_GGUF/QI8_1) + (2*k/QI8_1) % (WARP_SIZE_GGUF/QI8_1)]); } static __device__ __forceinline__ float vec_dot_q5_0_q8_1( @@ -664,8 +664,8 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0]; + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE_GGUF) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE_GGUF/QI5_0) + mmq_y/QI5_0]; *x_ql = tile_x_ql; *x_dm = (half2 *) tile_x_d; @@ -697,7 +697,7 @@ template static __device__ __forceinlin qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 - x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; + x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2*k+0] = qs0; int qs1 = (ql >> 4) & 0x0F0F0F0F; qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 @@ -706,10 +706,10 @@ template static __device__ __forceinlin qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 - x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; + x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2*k+1] = qs1; } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI5_0; const int kbxd = k % blocks_per_tile_x_row; float * x_dmf = (float *) x_dm; @@ -722,7 +722,7 @@ template static __device__ __forceinlin } const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = __half2float(bxi->d); + x_dmf[i * (WARP_SIZE_GGUF/QI5_0) + i / QI5_0 + kbxd] = 
__half2float(bxi->d); } } @@ -730,7 +730,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0; + const int index_bx = i * (WARP_SIZE_GGUF/QI5_0) + i/QI5_0 + k/QI5_0; const float * x_dmf = (const float *) x_dm; const float * y_df = (const float *) y_ds; @@ -738,12 +738,12 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( #pragma unroll for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; + u[2*l+0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF]; + u[2*l+1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI5_0) % WARP_SIZE_GGUF]; } return vec_dot_q8_0_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE_GGUF/QI8_1) + (2*k/QI8_1) % (WARP_SIZE_GGUF/QI8_1)]); } static __device__ __forceinline__ float vec_dot_q5_1_q8_1( @@ -767,8 +767,8 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1]; + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE_GGUF) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI5_1) + mmq_y/QI5_1]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -801,7 +801,7 @@ template static __device__ __forceinlin qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 - x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; + x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2*k+0] = qs0; int qs1 = (ql >> 4) & 0x0F0F0F0F; qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 @@ -809,10 +809,10 @@ template static __device__ __forceinlin qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 - x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; + x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2*k+1] = qs1; } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI5_1; const int kbxd = k % blocks_per_tile_x_row; #pragma unroll @@ -825,7 +825,7 @@ template static __device__ __forceinlin const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd; - x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; + x_dm[i * (WARP_SIZE_GGUF/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; } } @@ -833,18 +833,18 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1; + const int index_bx = i * (WARP_SIZE_GGUF/QI5_1) + + i/QI5_1 + k/QI5_1; int u[2*VDR_Q5_1_Q8_1_MMQ]; #pragma unroll for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + 
(kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE]; + u[2*l+0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF]; + u[2*l+1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI5_1) % WARP_SIZE_GGUF]; } return vec_dot_q8_1_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE_GGUF/QI8_1) + (2*k/QI8_1) % (WARP_SIZE_GGUF/QI8_1)]); } static __device__ __forceinline__ float vec_dot_q8_0_q8_1( @@ -865,8 +865,8 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0]; + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE_GGUF) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE_GGUF/QI8_0) + mmq_y/QI8_0]; *x_ql = tile_x_qs; *x_dm = (half2 *) tile_x_d; @@ -889,10 +889,10 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx); + x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_int8(bxi->qs, kqsx); } - const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI8_0; const int kbxd = k % blocks_per_tile_x_row; #pragma unroll @@ -903,7 +903,7 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = __half2float(bxi->d); + x_dmf[i * (WARP_SIZE_GGUF/QI8_0) + i / QI8_0 + kbxd] = __half2float(bxi->d); } } @@ -914,8 +914,8 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( const float * y_df = (const float *) y_ds; return vec_dot_q8_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0], - y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]); + (&x_ql[i * (WARP_SIZE_GGUF + 1) + k], &y_qs[j * WARP_SIZE_GGUF + k], x_dmf[i * (WARP_SIZE_GGUF/QI8_0) + i/QI8_0 + k/QI8_0], + y_df[j * (WARP_SIZE_GGUF/QI8_1) + k/QI8_1]); } static __device__ __forceinline__ float vec_dot_q2_K_q8_1( @@ -942,9 +942,9 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE_GGUF) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI2_K) + mmq_y/QI2_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/4) + mmq_y/4]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -967,10 +967,10 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); } - const int blocks_per_tile_x_row = WARP_SIZE / QI2_K; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI2_K; const int kbxd = k % blocks_per_tile_x_row; #pragma unroll @@ -981,18 +981,18 @@ template static 
__device__ __forceinlin i = min(i, i_max); } const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd; - x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm; + x_dm[i * (WARP_SIZE_GGUF/QI2_K) + i / QI2_K + kbxd] = bxi->dm; } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); + int i = i0 + i_offset * 4 + k / (WARP_SIZE_GGUF/4); if (need_check) { i = min(i, i_max); } - const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4); - x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4)); + const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/4)) / (QI2_K/4); + x_sc[i * (WARP_SIZE_GGUF/4) + i / 4 + k % (WARP_SIZE_GGUF/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4)); } } @@ -1005,7 +1005,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( int v[QR2_K*VDR_Q2_K_Q8_1_MMQ]; - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); + const int kqsx = i * (WARP_SIZE_GGUF + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2)); #pragma unroll @@ -1013,10 +1013,10 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; } - const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; + const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE_GGUF/4) + i/4 + kbx*4]) + ky/4; - const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE; - return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); + const int index_y = j * WARP_SIZE_GGUF + (QR2_K*k) % WARP_SIZE_GGUF; + return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE_GGUF/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); } static __device__ __forceinline__ float vec_dot_q3_K_q8_1( @@ -1047,10 +1047,10 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K]; - __shared__ int tile_x_qh[mmq_y * (WARP_SIZE/2) + mmq_y/2]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE_GGUF) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI3_K) + mmq_y/QI3_K]; + __shared__ int tile_x_qh[mmq_y * (WARP_SIZE_GGUF/2) + mmq_y/2]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/4) + mmq_y/4]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -1073,10 +1073,10 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); + x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); } - const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI3_K; const int kbxd = k % blocks_per_tile_x_row; float * x_dmf = (float *) x_dm; @@ -1087,27 +1087,27 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = __half2float(bxi->d); + x_dmf[i * (WARP_SIZE_GGUF/QI3_K) + i / QI3_K + kbxd] = 
__half2float(bxi->d); } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) { - int i = i0 + i_offset * 2 + k / (WARP_SIZE/2); + int i = i0 + i_offset * 2 + k / (WARP_SIZE_GGUF/2); if (need_check) { i = min(i, i_max); } - const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2); + const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/2)) / (QI3_K/2); // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted - x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2)); + x_qh[i * (WARP_SIZE_GGUF/2) + i / 2 + k % (WARP_SIZE_GGUF/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2)); } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); + int i = i0 + i_offset * 4 + k / (WARP_SIZE_GGUF/4); if (need_check) { i = min(i, i_max); } - const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4); + const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/4)) / (QI3_K/4); const int ksc = k % (QI3_K/4); @@ -1121,7 +1121,7 @@ template static __device__ __forceinlin const int sc = __vsubss4(sc_low | sc_high, 0x20202020); - x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc; + x_sc[i * (WARP_SIZE_GGUF/4) + i / 4 + k % (WARP_SIZE_GGUF/4)] = sc; } } @@ -1134,24 +1134,24 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat( const float * x_dmf = (const float *) x_dm; const float * y_df = (const float *) y_ds; - const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; + const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE_GGUF/4) + i/4 + kbx*4)) + ky/4; int v[QR3_K*VDR_Q3_K_Q8_1_MMQ]; #pragma unroll for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) { - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); + const int kqsx = i * (WARP_SIZE_GGUF + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); const int shift = 2 * ((ky % 32) / 8); const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; - const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); + const int vh = x_qh[i * (WARP_SIZE_GGUF/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); const int vlh = (vh << 2) & 0x04040404; v[l] = __vsubss4(vll, vlh); } - const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE; - return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); + const int index_y = j * WARP_SIZE_GGUF + (k*QR3_K) % WARP_SIZE_GGUF; + return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE_GGUF/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); } static __device__ __forceinline__ float vec_dot_q4_K_q8_1( @@ -1200,9 +1200,9 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE_GGUF) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI4_K) + mmq_y/QI4_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/8) + mmq_y/8]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -1225,10 +1225,10 @@ template static __device__ __forceinlin i = 
min(i, i_max); } const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI4_K; // == 1 if QK_K == 256 const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 #pragma unroll @@ -1238,27 +1238,27 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd; - x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm; + x_dm[i * (WARP_SIZE_GGUF/QI4_K) + i / QI4_K + kbxd] = bxi->dm; } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + int i = (i0 + i_offset * 8 + k / (WARP_SIZE_GGUF/8)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8); + const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/8)) / (QI4_K/8); const int * scales = (const int *) bxi->scales; - const int ksc = k % (WARP_SIZE/8); + const int ksc = k % (WARP_SIZE_GGUF/8); // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + x_sc[i * (WARP_SIZE_GGUF/8) + i / 8 + ksc] = scales8; } } @@ -1267,11 +1267,11 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat( const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { (void)x_qh; - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8); + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE_GGUF/8) + i/8 + k/16]) + 2*((k % 16) / 8); - const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE; - return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8, - x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); + const int index_y = j * WARP_SIZE_GGUF + (QR4_K*k) % WARP_SIZE_GGUF; + return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE_GGUF + 1) + k], &y_qs[index_y], sc, sc+8, + x_dm[i * (WARP_SIZE_GGUF/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); } static __device__ __forceinline__ float vec_dot_q5_K_q8_1( @@ -1321,9 +1321,9 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE_GGUF) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI5_K) + mmq_y/QI5_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/8) + mmq_y/8]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -1360,11 +1360,11 @@ template static __device__ __forceinlin const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0; const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4); - x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; - x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; + x_ql[i * (2*WARP_SIZE_GGUF + 
1) + kq0] = ql0 | qh0; + x_ql[i * (2*WARP_SIZE_GGUF + 1) + kq1] = ql1 | qh1; } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI5_K; // == 1 if QK_K == 256 const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 #pragma unroll @@ -1376,40 +1376,40 @@ template static __device__ __forceinlin } const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd; - x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm; + x_dm[i * (WARP_SIZE_GGUF/QI5_K) + i / QI5_K + kbxd] = bxi->dm; } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + int i = (i0 + i_offset * 8 + k / (WARP_SIZE_GGUF/8)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8); + const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/8)) / (QI5_K/8); const int * scales = (const int *) bxi->scales; - const int ksc = k % (WARP_SIZE/8); + const int ksc = k % (WARP_SIZE_GGUF/8); // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + x_sc[i * (WARP_SIZE_GGUF/8) + i / 8 + ksc] = scales8; } } static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8); + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE_GGUF/8) + i/8 + k/16]) + 2 * ((k % 16) / 8); - const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k; - const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE; + const int index_x = i * (QR5_K*WARP_SIZE_GGUF + 1) + QR5_K*k; + const int index_y = j * WARP_SIZE_GGUF + (QR5_K*k) % WARP_SIZE_GGUF; return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, - x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); + x_dm[i * (WARP_SIZE_GGUF/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); } static __device__ __forceinline__ float vec_dot_q6_K_q8_1( @@ -1439,9 +1439,9 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE_GGUF) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI6_K) + mmq_y/QI6_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/8) + mmq_y/8]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -1478,11 +1478,11 @@ template static __device__ __forceinlin const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0; const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2); - x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); - x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); + x_ql[i * (2*WARP_SIZE_GGUF + 1) + kq0] = __vsubss4(ql0 | qh0, 
0x20202020); + x_ql[i * (2*WARP_SIZE_GGUF + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); } - const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI6_K; // == 1 if QK_K == 256 const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 float * x_dmf = (float *) x_dm; @@ -1496,20 +1496,20 @@ template static __device__ __forceinlin const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = __half2float(bxi->d); + x_dmf[i * (WARP_SIZE_GGUF/QI6_K) + i / QI6_K + kbxd] = __half2float(bxi->d); } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + int i = (i0 + i_offset * 8 + k / (WARP_SIZE_GGUF/8)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4; + const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/8)) / 4; - x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8)); + x_sc[i * (WARP_SIZE_GGUF/8) + i / 8 + k % (WARP_SIZE_GGUF/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8)); } } @@ -1519,11 +1519,11 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat( const float * x_dmf = (const float *) x_dm; const float * y_df = (const float *) y_ds; - const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]); + const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE_GGUF/8) + i/8 + k/8]); - const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k; - const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE; - return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); + const int index_x = i * (QR6_K*WARP_SIZE_GGUF + 1) + QR6_K*k; + const int index_y = j * WARP_SIZE_GGUF + (QR6_K*k) % WARP_SIZE_GGUF; + return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE_GGUF/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); } static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( @@ -1703,7 +1703,7 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq1_m * bq1 = (const block_iq1_m *) vbq; diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 6b6dcbf9faa4a..36d1f1af15a6a 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -218,6 +218,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " "SymInt size_n, int num_bits) -> Tensor"); // conditionally compiled so impl registrations are in source file +#endif // Dequantization for GGML. 
ops.def("ggml_dequantize(Tensor W, int type, SymInt m, SymInt n) -> Tensor"); diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 46a2fb8bc80a2..41ff7d6bd7c31 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -345,31 +345,6 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor, use_fp32_reduce: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - @register_fake("_C::ggml_dequantize") - def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, - m: torch.SymInt, - n: torch.SymInt) -> torch.Tensor: - return torch.empty((m, n), dtype=torch.float16, device=W.device) - - @register_fake("_C::ggml_mul_mat_vec_a8") - def _ggml_mul_mat_vec_a8_fake( - W: torch.Tensor, - X: torch.Tensor, - quant_type: int, - row: torch.SymInt, - ) -> torch.Tensor: - return torch.empty((1, row), dtype=torch.float16, device=W.device) - - @register_fake("_C::ggml_mul_mat_a8") - def _ggml_mul_mat_a8_fake( - W: torch.Tensor, - X: torch.Tensor, - quant_type: int, - row: torch.SymInt, - ) -> torch.Tensor: - batch = X.size(0) - return torch.empty((batch, row), dtype=torch.float16, device=W.device) - @register_fake("_C::marlin_qqq_gemm") def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, s_tok: torch.Tensor, s_ch: torch.Tensor, @@ -467,6 +442,32 @@ def machete_prepack_B_fake(b_q_weight: torch.Tensor, return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) +if hasattr(torch.ops._C, "ggml_dequantize"): + + @register_fake("_C::ggml_dequantize") + def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, + m: torch.SymInt, + n: torch.SymInt) -> torch.Tensor: + return torch.empty((m, n), dtype=torch.float16, device=W.device) + + @register_fake("_C::ggml_mul_mat_vec_a8") + def _ggml_mul_mat_vec_a8_fake( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: torch.SymInt, + ) -> torch.Tensor: + return torch.empty((1, row), dtype=torch.float16, device=W.device) + + @register_fake("_C::ggml_mul_mat_a8") + def _ggml_mul_mat_a8_fake( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: torch.SymInt, + ) -> torch.Tensor: + batch = X.size(0) + return torch.empty((batch, row), dtype=torch.float16, device=W.device) # cutlass def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: From 03a9216a98796a846fb9898089a12933be0534d2 Mon Sep 17 00:00:00 2001 From: miloice Date: Thu, 7 Nov 2024 09:36:35 +0000 Subject: [PATCH 4/6] run mmvq kernels on 64 thread warps on rocm --- csrc/quantization/gguf/gguf_kernel.cu | 4 +-- csrc/quantization/gguf/mmvq.cuh | 48 ++++++++++++--------------- 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu index 54c5267c24ee2..5f0eaf5a973fb 100644 --- a/csrc/quantization/gguf/gguf_kernel.cu +++ b/csrc/quantization/gguf/gguf_kernel.cu @@ -4,14 +4,14 @@ #include #include +#include "cuda_compat.h" + #include "ggml-common.h" #include "vecdotq.cuh" #include "dequantize.cuh" #include "mmvq.cuh" #include "mmq.cuh" -#include "cuda_compat.h" - // Q8 gemv static __global__ void quantize_q8_1(const half* __restrict__ x, void* __restrict__ vy, const int kx, diff --git a/csrc/quantization/gguf/mmvq.cuh b/csrc/quantization/gguf/mmvq.cuh index 82dd69c335fcd..b01e939808a3f 100644 --- a/csrc/quantization/gguf/mmvq.cuh +++ b/csrc/quantization/gguf/mmvq.cuh @@ -8,7 +8,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * } const int blocks_per_row = ncols / qk; - const int 
blocks_per_warp = vdr * WARP_SIZE_GGUF / qi; + const int blocks_per_warp = vdr * WARP_SIZE / qi; // partial sum for each thread float tmp = 0.0f; @@ -28,12 +28,8 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { -#ifndef USE_ROCM - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); -#else - tmp += __shfl_xor(tmp, mask, 32); -#endif + for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) { + tmp += VLLM_SHFL_XOR_SYNC(tmp, mask); } if (threadIdx.x == 0) { @@ -44,7 +40,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -52,7 +48,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -60,7 +56,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -68,7 +64,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -76,7 +72,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -84,7 +80,7 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); 
- const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -92,7 +88,7 @@ static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -100,7 +96,7 @@ static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -108,7 +104,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -116,7 +112,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -124,7 +120,7 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -132,7 +128,7 @@ static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, half static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -140,7 +136,7 @@ static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, half static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { 
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -148,7 +144,7 @@ static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -156,7 +152,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, half static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -164,7 +160,7 @@ static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -172,7 +168,7 @@ static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, half * static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -180,7 +176,7 @@ static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } @@ -188,7 +184,7 @@ static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, half static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE_GGUF, GGML_CUDA_MMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } From 41c83ddfdc33f31aacdb8b4e9d13190c7693a032 Mon Sep 17 00:00:00 2001 From: miloice Date: Thu, 14 Nov 2024 08:24:33 
+0000 Subject: [PATCH 5/6] enable gguf tests on amd --- .buildkite/run-amd-test.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 860272e71fd84..37a33431e2d33 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -83,7 +83,6 @@ if [[ $commands == *" kernels "* ]]; then --ignore=kernels/test_encoder_decoder_attn.py \ --ignore=kernels/test_flash_attn.py \ --ignore=kernels/test_flashinfer.py \ - --ignore=kernels/test_gguf.py \ --ignore=kernels/test_int8_quant.py \ --ignore=kernels/test_machete_gemm.py \ --ignore=kernels/test_mamba_ssm.py \ From 7ee595c884996e49be0e3747016db3da2c441969 Mon Sep 17 00:00:00 2001 From: miloice Date: Thu, 14 Nov 2024 08:55:52 +0000 Subject: [PATCH 6/6] format --- vllm/_custom_ops.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 0b9afebec2c4c..6cbf03c388d13 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -442,8 +442,9 @@ def machete_prepack_B_fake(b_q_weight: torch.Tensor, return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) + if hasattr(torch.ops._C, "ggml_dequantize"): - + @register_fake("_C::ggml_dequantize") def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: torch.SymInt, @@ -469,6 +470,7 @@ def _ggml_mul_mat_a8_fake( batch = X.size(0) return torch.empty((batch, row), dtype=torch.float16, device=W.device) + # cutlass def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
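
Note on the warp-size handling introduced above (PATCH 4/6): the mmvq reduction no longer hard-codes a 32-lane warp; it loops from WARP_SIZE/2 down and launches blocks of (WARP_SIZE, GGML_CUDA_MMV_Y, 1), so the same kernel reduces over 64-lane wavefronts on ROCm. The sketch below is a minimal, self-contained version of that pattern only; PORTABLE_WARP_SIZE and SHFL_XOR are stand-in macros written for this sketch, not the WARP_SIZE / VLLM_SHFL_XOR_SYNC definitions the patches pull in from csrc/cuda_compat.h.

// Stand-ins for the compat macros (assumed for this sketch only).
#ifdef USE_ROCM
  #define PORTABLE_WARP_SIZE 64
  #define SHFL_XOR(var, mask) __shfl_xor((var), (mask))
#else
  #define PORTABLE_WARP_SIZE 32
  #define SHFL_XOR(var, mask) __shfl_xor_sync(0xffffffffu, (var), (mask))
#endif

// Butterfly sum across one warp without assuming 32 lanes.
__device__ float warp_reduce_sum(float v) {
#pragma unroll
    for (int mask = PORTABLE_WARP_SIZE / 2; mask > 0; mask >>= 1) {
        v += SHFL_XOR(v, mask);
    }
    return v;  // every lane now holds the warp-wide sum
}

// One warp per output row: lanes stride over the columns, then reduce.
// Launch as: dot_per_row<<<nrows, PORTABLE_WARP_SIZE>>>(x, y, out, ncols);
__global__ void dot_per_row(const float* x, const float* y, float* out, int ncols) {
    const int row = blockIdx.x;
    float partial = 0.0f;
    for (int c = threadIdx.x; c < ncols; c += PORTABLE_WARP_SIZE) {
        partial += x[(size_t)row * ncols + c] * y[c];
    }
    partial = warp_reduce_sum(partial);
    if (threadIdx.x == 0) {
        out[row] = partial;
    }
}

With 64-lane wavefronts the reduction loop simply starts at mask = 32, and in mul_mat_vec_q the wider warp also doubles blocks_per_warp (vdr * WARP_SIZE / qi), so each thread finishes its row in roughly half as many iterations; no per-backend #ifdef is needed in the kernel body itself.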
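
The K-quant tile loaders in the hunks further up keep a related layout detail: shared tiles are sized as mmq_y * WARP_SIZE_GGUF + mmq_y (or 2*WARP_SIZE_GGUF for the 5- and 6-bit variants) and indexed with a stride of WARP_SIZE_GGUF + 1, i.e. one extra element of padding per tile row. That padding is the usual trick for staggering rows across shared-memory banks; the rationale given here is the standard one for this layout, not something the patches state. A minimal sketch of the same indexing, with hypothetical TILE_ROWS / TILE_COLS constants standing in for mmq_y and WARP_SIZE_GGUF (square tile assumed):

#define TILE_ROWS 32   // hypothetical stand-in for mmq_y
#define TILE_COLS 32   // hypothetical stand-in for WARP_SIZE_GGUF

// Square tile transpose through shared memory, flat-indexed like the loaders above.
// Launch as: transpose_tile<<<1, dim3(TILE_COLS, TILE_ROWS)>>>(in, out);
__global__ void transpose_tile(const int* in, int* out) {
    // One pad element per row, exactly the i * (TILE_COLS + 1) + k pattern.
    __shared__ int tile[TILE_ROWS * (TILE_COLS + 1)];

    const int x = threadIdx.x;   // column within the tile
    const int y = threadIdx.y;   // row within the tile

    tile[y * (TILE_COLS + 1) + x] = in[y * TILE_COLS + x];   // coalesced load
    __syncthreads();

    // The transposed read walks down a column of `tile`; the +1 pad spreads
    // the lanes of a warp over different banks instead of piling them on one.
    out[y * TILE_ROWS + x] = tile[x * (TILE_COLS + 1) + y];
}

Because the pad is baked into the index arithmetic (the i * (WARP_SIZE_GGUF + 1) + k terms), the change in those hunks is essentially a rename of the constant: the tile layout itself is untouched.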