Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel][Hardware][AMD] Add support for GGUF quantization on ROCm #10254

Merged
merged 10 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .buildkite/run-amd-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ if [[ $commands == *" kernels "* ]]; then
--ignore=kernels/test_encoder_decoder_attn.py \
--ignore=kernels/test_flash_attn.py \
--ignore=kernels/test_flashinfer.py \
--ignore=kernels/test_gguf.py \
--ignore=kernels/test_int8_quant.py \
--ignore=kernels/test_machete_gemm.py \
--ignore=kernels/test_mamba_ssm.py \
Expand Down
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/torch_bindings.cpp")
Expand Down Expand Up @@ -237,7 +238,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
Expand Down
2 changes: 2 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
int64_t thx, int64_t thy);

torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
#endif

torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
int64_t n);
Expand All @@ -138,6 +139,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
int64_t row);

#ifndef USE_ROCM
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);

void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
Expand Down
17 changes: 16 additions & 1 deletion csrc/quantization/gguf/ggml-common.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
#define QK_K 256
#define K_QUANTS_PER_ITERATION 2
#define WARP_SIZE 32
#define WARP_SIZE_GGUF 32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to change this name? It seems the vast majority of the changes in this PR are due to this rename, so would prefer to keep it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The WARP_SIZE macro here conflicts with one defined in "cuda_compat.h" which has cross-platform utilities that this port uses. For CUDA the macro redefinition may be just fine because they are replaced by the same values. But for ROCm, in the gguf quantization kernel, some of the symbols need to be replaced by values different from that in "cuda_compat.h" while the others are kept the same, to utilize full waves on wave 64 devices. So I thought to change the macro name referenced in the quantization for clarity, though it does make quite a few mundane changes.

Another way could be to change the macro through perhaps undefine/define to the desired values at their respective places. Wouldn't need to change names in those places if we do this but the same symbol would represent different values in different places. I'm fine with either options, and would also be interested in knowing which one you would prefer.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's clear justification, thank you for your thoughts. I prefer the gguf name as it is more explicit, so if we must make changes with it then let's keep what you have

#define K_SCALE_SIZE 12
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
#define CUDA_QUANTIZE_BLOCK_SIZE 256
Expand Down Expand Up @@ -1112,4 +1112,19 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#endif
return c;
}

static __device__ __forceinline__ uint32_t __vcmpeq4(const uint32_t a, const uint32_t b) {
uint32_t neq = a^b;
return !(neq & 0xff000000) * 0xff000000 |
!(neq & 0x00ff0000) * 0x00ff0000 |
!(neq & 0x0000ff00) * 0x0000ff00 |
!(neq & 0x000000ff) * 0x000000ff;
}

static __device__ __forceinline__ uint32_t __vsub4(const uint32_t a, const uint32_t b) {
return (static_cast<uint8_t>(((a & 0xff000000) >> 24) - ((b & 0xff000000) >> 24)) << 24) +
(static_cast<uint8_t>(((a & 0x00ff0000) >> 16) - ((b & 0x00ff0000) >> 16)) << 16) +
(static_cast<uint8_t>(((a & 0x0000ff00) >> 8) - ((b & 0x0000ff00) >> 8)) << 8) +
(static_cast<uint8_t>(((a & 0x000000ff) >> 0) - ((b & 0x000000ff) >> 0)) << 0);
}
#endif // defined(USE_ROCM)
6 changes: 4 additions & 2 deletions csrc/quantization/gguf/gguf_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"

#include "ggml-common.h"
#include "vecdotq.cuh"
#include "dequantize.cuh"
Expand Down Expand Up @@ -32,8 +34,8 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,

#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
amax = fmaxf(amax, VLLM_SHFL_XOR_SYNC_WIDTH(amax, mask, 32));
sum += VLLM_SHFL_XOR_SYNC_WIDTH(sum, mask, 32);
}

const float d = amax / 127;
Expand Down
70 changes: 35 additions & 35 deletions csrc/quantization/gguf/mmq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ static __device__ __forceinline__ void mul_mat_q(

const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_col_y = nrows_y / QK8_1;
const int blocks_per_warp = WARP_SIZE / qi;
const int blocks_per_warp = WARP_SIZE_GGUF / qi;

const int & ncols_dst = ncols_y;

Expand All @@ -27,10 +27,10 @@ static __device__ __forceinline__ void mul_mat_q(

allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);

__shared__ int tile_y_qs[mmq_x * WARP_SIZE];
__shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1];
__shared__ int tile_y_qs[mmq_x * WARP_SIZE_GGUF];
__shared__ half2 tile_y_ds[mmq_x * WARP_SIZE_GGUF/QI8_1];

float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};
float sum[mmq_y/WARP_SIZE_GGUF][mmq_x/nwarps] = {{0.0f}};

for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {

Expand All @@ -39,26 +39,26 @@ static __device__ __forceinline__ void mul_mat_q(

#pragma unroll
for (int ir = 0; ir < qr; ++ir) {
const int kqs = ir*WARP_SIZE + threadIdx.x;
const int kqs = ir*WARP_SIZE_GGUF + threadIdx.x;
const int kbxd = kqs / QI8_1;

#pragma unroll
for (int i = 0; i < mmq_x; i += nwarps) {
const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE;
const int index_y = (threadIdx.y + i) * WARP_SIZE_GGUF + kqs % WARP_SIZE_GGUF;
tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
}

#pragma unroll
for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE_GGUF/QI8_1)) % mmq_x;
const int kby = threadIdx.x % (WARP_SIZE_GGUF/QI8_1);
const int col_y_eff = min(col_y_0 + ids, ncols_y-1);

// if the sum is not needed it's faster to transform the scale to f32 ahead of time
const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds;
half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE_GGUF/QI8_1) + kby].ds;
half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE_GGUF/QI8_1) + kby];
if (need_sum) {
*dsi_dst = *dsi_src;
} else {
Expand All @@ -70,12 +70,12 @@ static __device__ __forceinline__ void mul_mat_q(
__syncthreads();

// #pragma unroll // unrolling this loop causes too much register pressure
for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
for (int k = ir*WARP_SIZE_GGUF/qr; k < (ir+1)*WARP_SIZE_GGUF/qr; k += vdr) {
#pragma unroll
for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
for (int i = 0; i < mmq_y; i += WARP_SIZE) {
sum[i/WARP_SIZE][j/nwarps] += vec_dot(
for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
sum[i/WARP_SIZE_GGUF][j/nwarps] += vec_dot(
tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
threadIdx.x + i, threadIdx.y + j, k);
}
Expand All @@ -93,12 +93,12 @@ static __device__ __forceinline__ void mul_mat_q(
}

#pragma unroll
for (int i = 0; i < mmq_y; i += WARP_SIZE) {
for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
const int row_dst = row_dst_0 + threadIdx.x + i;
if (row_dst >= nrows_dst) {
continue;
}
dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE][j/nwarps]);
dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE_GGUF][j/nwarps]);
}
}
}
Expand All @@ -115,7 +115,7 @@ static __device__ __forceinline__ void mul_mat_q(

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q4_0, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_0, 2)
#endif
mul_mat_q4_0(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
Expand All @@ -140,7 +140,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
Expand All @@ -165,7 +165,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q4_1, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_1, 2)
#endif
mul_mat_q4_1(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
Expand All @@ -190,7 +190,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
Expand All @@ -215,7 +215,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q5_0, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_0, 2)
#endif
mul_mat_q5_0(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
Expand All @@ -240,7 +240,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
Expand All @@ -265,7 +265,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q5_1, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_1, 2)
#endif
mul_mat_q5_1(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
Expand All @@ -289,7 +289,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
Expand All @@ -314,7 +314,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q8_0, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q8_0, 2)
#endif
mul_mat_q8_0(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
Expand All @@ -338,7 +338,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
Expand All @@ -363,7 +363,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q2_K, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q2_K, 2)
#endif
mul_mat_q2_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
Expand All @@ -387,7 +387,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
Expand All @@ -412,7 +412,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q3_K, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q3_K, 2)
#endif
mul_mat_q3_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
Expand All @@ -438,7 +438,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
Expand All @@ -463,7 +463,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q4_K, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_K, 2)
#endif
mul_mat_q4_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
Expand All @@ -487,7 +487,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
Expand All @@ -512,7 +512,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q5_K, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_K, 2)
#endif
mul_mat_q5_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
Expand All @@ -537,7 +537,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
Expand All @@ -562,7 +562,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE*NWARPS_Q6_K, 2)
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q6_K, 2)
#endif
mul_mat_q6_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
Expand All @@ -586,7 +586,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
Expand Down
4 changes: 2 additions & 2 deletions csrc/quantization/gguf/mmvq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *

// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
tmp += VLLM_SHFL_XOR_SYNC(tmp, mask);
}

if (threadIdx.x == 0) {
Expand Down
Loading