From b027d06f339db1b445d26587d4171b4a289808e7 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 27 Nov 2023 05:11:53 +0000 Subject: [PATCH 1/2] Hipifying csrc file to accomodate rocm builds --- .gitignore | 4 + csrc/activation_kernels.cu | 7 +- csrc/attention/attention_kernels.cu | 48 +++-- csrc/attention/attention_utils.cuh | 3 +- csrc/attention/dtype_bfloat16.cuh | 19 +- csrc/attention/dtype_float16.cuh | 90 ++++++++- csrc/cache_kernels.cu | 13 +- csrc/cuda_compat.h | 19 ++ csrc/cuda_utils_kernels.cu | 3 + csrc/pos_encoding_kernels.cu | 9 +- .../squeezellm/quant_cuda_kernel.cu | 76 +++++++- csrc/reduction_utils.cuh | 4 +- setup.py | 181 ++++++++++-------- 13 files changed, 361 insertions(+), 115 deletions(-) create mode 100644 csrc/cuda_compat.h diff --git a/.gitignore b/.gitignore index b531b7918c305..58572718850e6 100644 --- a/.gitignore +++ b/.gitignore @@ -177,3 +177,7 @@ _build/ # vim swap files *.swo *.swp + +# hip files generated by PyTorch +*.hip +*_hip* diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 89d1ba2d37dd8..1cca2c5fccc18 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "cuda_compat.h" #include "dispatch_utils.h" namespace vllm { @@ -18,8 +19,8 @@ __global__ void silu_and_mul_kernel( const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]); - const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]); + const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); + const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); out[token_idx * d + idx] = silu(x) * y; } } @@ -57,7 +58,7 @@ __global__ void activation_kernel( const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = __ldg(&input[token_idx * d + idx]); + const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); out[token_idx * d + idx] = ACT_FN(x); } } diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 78e8d8ecd6d41..19ad43b554aff 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -23,7 +23,11 @@ #include +#ifndef USE_ROCM #define WARP_SIZE 32 +#else +#define WARP_SIZE 64 +#endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) @@ -40,7 +44,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Compute the sum per warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + sum += VLLM_SHFL_XOR_SYNC(sum, mask); } // Warp leaders store the data to shared memory. @@ -59,11 +63,11 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Parallel reduction inside the warp. #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + sum += VLLM_SHFL_XOR_SYNC(sum, mask); } // Broadcast to other threads. - return __shfl_sync(uint32_t(-1), sum, 0); + return VLLM_SHFL_SYNC(sum, 0); } // TODO(woosuk): Merge the last two dimensions of the grid. @@ -223,7 +227,7 @@ __device__ void paged_attention_kernel( // The 0-th thread of each thread group already has its max qk value. 
#pragma unroll for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); } if (lane == 0) { red_smem[warp_idx] = qk_max; @@ -235,10 +239,10 @@ __device__ void paged_attention_kernel( qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); } // Broadcast the max qk value to all threads. - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + qk_max = VLLM_SHFL_SYNC(qk_max, 0); // Get the sum of the exp values. float exp_sum = 0.f; @@ -326,7 +330,7 @@ __device__ void paged_attention_kernel( float acc = accs[i]; #pragma unroll for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { - acc += __shfl_xor_sync(uint32_t(-1), acc, mask); + acc += VLLM_SHFL_XOR_SYNC(acc, mask); } accs[i] = acc; } @@ -492,7 +496,7 @@ __global__ void paged_attention_v2_reduce_kernel( // Reduce within the warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); } if (lane == 0) { red_smem[warp_idx] = max_logit; @@ -502,10 +506,10 @@ __global__ void paged_attention_v2_reduce_kernel( max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); } // Broadcast the max value to all threads. - max_logit = __shfl_sync(uint32_t(-1), max_logit, 0); + max_logit = VLLM_SHFL_SYNC(max_logit, 0); // Load rescaled exp sums to shared memory. float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); @@ -538,9 +542,10 @@ __global__ void paged_attention_v2_reduce_kernel( } // namespace vllm +#ifndef USE_ROCM #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ cudaFuncSetAttribute( \ - vllm::paged_attention_v1_kernel, \ + (void*)vllm::paged_attention_v1_kernel, \ cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ vllm::paged_attention_v1_kernel \ <<>>( \ @@ -557,6 +562,27 @@ __global__ void paged_attention_v2_reduce_kernel( q_stride, \ kv_block_stride, \ kv_head_stride); +#else +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + hipFuncSetAttribute( \ + (void*)vllm::paged_attention_v1_kernel, \ + hipFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + head_mapping_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride); +#endif // TODO(woosuk): Tune NUM_THREADS. 
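// --- Illustrative sketch, not part of the patch ---
// The edits above swap the raw __shfl_*_sync intrinsics for the VLLM_SHFL_* macros from
// csrc/cuda_compat.h (added later in this patch) and widen WARP_SIZE to 64 under
// USE_ROCM, since the CDNA targets (gfx906/gfx908/gfx90a) run 64-lane wavefronts.
// A minimal portable warp reduction written against those macros could look like the
// hypothetical helper below; after the xor butterfly every lane already holds the full
// sum, so no extra broadcast is needed.
template<typename T>
__device__ T warp_sum(T val) {
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    val += VLLM_SHFL_XOR_SYNC(val, mask);
  }
  return val;
}
// --- end sketch ---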
template< diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index bb7df25b14f02..ff64c4bd8f80c 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -17,6 +17,7 @@ */ #pragma once +#include "../cuda_compat.h" #include "attention_dtypes.h" #include @@ -39,7 +40,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { float qk = sum(qk_vec); #pragma unroll for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + qk += VLLM_SHFL_XOR_SYNC(qk, mask); } return qk; } diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 5786f77f7bca6..31e0cee01d2e1 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -21,8 +21,17 @@ #include "attention_generic.cuh" #include "dtype_float32.cuh" -#include -#include +#ifndef USE_ROCM + #include + #include +#else + #include + #include + + typedef __hip_bfloat162 __nv_bfloat162; + typedef __hip_bfloat16 __nv_bfloat16; +#endif + #include namespace vllm { @@ -98,7 +107,11 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else - return a + b; + #ifndef USE_ROCM + return a + b; + #else + return __hadd(a, b); + #endif #endif } diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index e67921128d52b..1ec2009c74c87 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -21,6 +21,10 @@ #include "attention_generic.cuh" #include "dtype_float32.cuh" +#ifdef USE_ROCM + #include +#endif + #include namespace vllm { @@ -63,58 +67,114 @@ struct FloatVec { // Utility functions for type conversions. 
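// On the CUDA path the conversion helpers below keep their PTX inline assembly
// (cvt.f32.f16, cvt.rn.f16.f32, ...). Under USE_ROCM the same operations are expressed
// either with the matching GCN VALU instructions (v_cvt_f32_f16, v_cvt_f16_f32) or by
// reinterpreting raw bits through a uint16_t/uint32_t union, since uint16_t and
// uint32_t here are simply the bit patterns of a half and a packed half2.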
inline __device__ uint32_t h0_h0(uint16_t a) { +#ifndef USE_ROCM uint32_t b; asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); return b; +#else + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = a; + tmp.u16[1] = a; + return tmp.u32; +#endif } inline __device__ float half_to_float(uint16_t h) { +#ifndef USE_ROCM float f; asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); return f; +#else + float f; + asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h)); + return f; +#endif } inline __device__ float2 half2_to_float2(uint32_t v) { +#ifndef USE_ROCM uint16_t lo, hi; asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); return make_float2(half_to_float(lo), half_to_float(hi)); +#else + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u32 = v; + float2 ret; + ret.x = half_to_float(tmp.u16[0]); + ret.y = half_to_float(tmp.u16[1]); + return ret; +#endif } inline __device__ uint16_t float_to_half(float f) { +#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); return tmp.u16[0]; +#else + uint16_t ret; + asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(ret) : "v"(f)); + return ret; +#endif } inline __device__ uint32_t float2_to_half2(float2 f) { +#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); + #else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + #endif + return tmp.u32; #else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); -#endif + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = float_to_half(f.x); + tmp.u16[1] = float_to_half(f.y); return tmp.u32; +#endif } // Vector addition. inline __device__ uint16_t add(uint16_t a, uint16_t b) { +#ifndef USE_ROCM uint16_t c; asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); return c; +#else + uint16_t c; + asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); + return c; +#endif } inline __device__ uint32_t add(uint32_t a, uint32_t b) { +#ifndef USE_ROCM uint32_t c; asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; +#else + uint32_t c; + asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); + return c; +#endif } inline __device__ uint2 add(uint2 a, uint2 b) { @@ -157,16 +217,28 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { // Vector multiplication. template<> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { +#ifndef USE_ROCM uint16_t c; asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); return c; +#else + uint16_t c; + asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); + return c; +#endif } template<> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { +#ifndef USE_ROCM uint32_t c; asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; +#else + uint32_t c; + asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); + return c; +#endif } template<> @@ -271,9 +343,15 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { // Vector fused multiply-add. 
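// The fused multiply-add overloads below follow the same pattern as the add/mul
// overloads above: each packed f16x2 PTX instruction (add.f16x2, mul.f16x2,
// fma.rn.f16x2) is replaced under USE_ROCM by AMD's packed VALU equivalent
// (v_pk_add_f16, v_pk_mul_f16, v_pk_fma_f16), while the scalar f16 ops map to
// v_add_f16 and v_mul_f16.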
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { +#ifndef USE_ROCM uint32_t d; asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); return d; +#else + uint32_t d; + asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); + return d; +#endif } inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { @@ -441,4 +519,4 @@ inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } -} // namespace vllm +} // namespace vllm \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 3ad52b1681c0c..59bacffdf4642 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "cuda_compat.h" #include "dispatch_utils.h" #include @@ -28,8 +29,8 @@ void swap_blocks( TORCH_CHECK(false, "Invalid device combination"); } - void *src_ptr = src.data_ptr(); - void *dst_ptr = dst.data_ptr(); + char *src_ptr = static_cast(src.data_ptr()); + char *dst_ptr = static_cast(dst.data_ptr()); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -267,8 +268,8 @@ __global__ void gather_cached_kv_kernel( + head_offset * block_size + block_offset; - key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]); - value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]); + key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]); + value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]); } } @@ -333,8 +334,8 @@ __global__ void gather_cached_kv_kernel_optimized( src_key_indices[j] = src_key_idx; src_value_indices[j] = src_value_idx; - keys_to_store[j] = __ldg(&key_cache[src_key_idx]); - values_to_store[j] = __ldg(&value_cache[src_value_idx]); + keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]); + values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]); } #pragma unroll diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h new file mode 100644 index 0000000000000..fcafc7e16687c --- /dev/null +++ b/csrc/cuda_compat.h @@ -0,0 +1,19 @@ +#pragma once + +#ifndef USE_ROCM + #define VLLM_LDG(arg) __ldg(arg) +#else + #define VLLM_LDG(arg) *(arg) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) +#else + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane); +#else + #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) +#endif \ No newline at end of file diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index f1c30fe7ea99e..6c844a7f6c6ed 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -1,3 +1,6 @@ +#ifdef USE_ROCM + #include +#endif int get_device_attribute( int attribute, int device_id) diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 0a5ec95f8c0d4..e1dc711778ffb 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "cuda_compat.h" #include "dispatch_utils.h" namespace vllm { @@ -19,14 +20,14 @@ inline __device__ void apply_rotary_embedding( // GPT-NeoX style rotary embedding. x_index = rot_offset; y_index = embed_dim + rot_offset; - cos = __ldg(cos_ptr + x_index); - sin = __ldg(sin_ptr + x_index); + cos = VLLM_LDG(cos_ptr + x_index); + sin = VLLM_LDG(sin_ptr + x_index); } else { // GPT-J style rotary embedding. 
x_index = 2 * rot_offset; y_index = 2 * rot_offset + 1; - cos = __ldg(cos_ptr + x_index / 2); - sin = __ldg(sin_ptr + x_index / 2); + cos = VLLM_LDG(cos_ptr + x_index / 2); + sin = VLLM_LDG(sin_ptr + x_index / 2); } const scalar_t x = arr[x_index]; diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu index 1392b877397be..c2f77da15977a 100644 --- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu +++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu @@ -20,9 +20,17 @@ __device__ inline unsigned int as_unsigned(int i) { // 4-bit matvec kernel (LUT-based) __global__ void NUQ4MatMulKernel( +#ifndef USE_ROCM const half2* __restrict__ vec, +#else + const __half2* __restrict__ vec, +#endif const int* __restrict__ mat, +#ifndef USE_ROCM half2* __restrict__ mul, +#else + float2* __restrict__ mul, +#endif const __half* __restrict__ lookup_table, int height, int width, @@ -35,7 +43,11 @@ __global__ void NUQ4MatMulKernel( int row = BLOCKHEIGHT4 * blockIdx.x; int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; +#ifndef USE_ROCM __shared__ half2 blockvec[blockwidth2]; +#else + __shared__ __half2 blockvec[blockwidth2]; +#endif __shared__ __half deq2[16][BLOCKWIDTH]; int off = threadIdx.x; @@ -46,8 +58,13 @@ __global__ void NUQ4MatMulKernel( } __half res; +#ifndef USE_ROCM half2 res2; half2 tmp2; +#else + __half2 res2; + __half2 tmp2; +#endif int i; int k; @@ -68,48 +85,96 @@ __global__ void NUQ4MatMulKernel( while (k < blockwidth2) { tmp1 = as_unsigned(mat[i]); +#ifndef USE_ROCM res2 = {}; tmp2 = {}; +#else + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); + tmp2.x = __half_as_ushort(__float2half(0)); + tmp2.y = __half_as_ushort(__float2half(0)); +#endif lut_index1 = tmp1 & 0xF; lut_index2 = (tmp1 >> 4) & 0xF; +#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; +#else + tmp2.x = __half_as_ushort(deq2[lut_index1][off]); + tmp2.y = __half_as_ushort(deq2[lut_index2][off]); +#endif res2 = __hfma2(tmp2, blockvec[k + 0], res2); lut_index1 = (tmp1 >> 8) & 0xF; lut_index2 = (tmp1 >> 12) & 0xF; +#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; +#else + tmp2.x = __half_as_ushort(deq2[lut_index1][off]); + tmp2.y = __half_as_ushort(deq2[lut_index2][off]); +#endif res2 = __hfma2(tmp2, blockvec[k + 1], res2); lut_index1 = (tmp1 >> 16) & 0xF; lut_index2 = (tmp1 >> 20) & 0xF; +#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; +#else + tmp2.x = __half_as_ushort(deq2[lut_index1][off]); + tmp2.y = __half_as_ushort(deq2[lut_index2][off]); +#endif res2 = __hfma2(tmp2, blockvec[k + 2], res2); lut_index1 = (tmp1 >> 24) & 0xF; lut_index2 = (tmp1 >> 28) & 0xF; +#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; +#else + tmp2.x = __half_as_ushort(deq2[lut_index1][off]); + tmp2.y = __half_as_ushort(deq2[lut_index2][off]); +#endif res2 = __hfma2(tmp2, blockvec[k + 3], res2); +#ifndef USE_ROCM res = __hadd(__hadd(res2.x, res2.y), res); +#else + res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res); +#endif i += width; k += 4; } // col%2 -> only set one of the two values +#ifndef USE_ROCM half2 res3 = {}; if (col % 2 == 0) { res3.x = res; } else { res3.y = res; } +#else + __half2 res3; + res3.x = __half_as_ushort(__float2half(0)); + res3.y = __half_as_ushort(__float2half(0)); + if (col % 2 == 0) { + res3.x = __half_as_ushort(res); + } else { + res3.y = __half_as_ushort(res); + } +#endif 
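// On the CUDA path the partial result is flushed with a single packed half2 atomicAdd
// into `mul`. The ROCm path below instead types `mul` as float2 and issues two scalar
// float atomicAdds, presumably because a packed __half2 atomicAdd overload is not
// available in this HIP toolchain; the fp32 output is converted back to fp16 on the
// Python side (see the squeezellm.py change in the second commit of this patch).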
+#ifndef USE_ROCM atomicAdd(&mul[b * width / 2 + col / 2], res3); +#else + int tmp_addr = b * width / 2 + col / 2; + atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x))); + atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y))); +#endif } } @@ -136,13 +201,22 @@ void squeezellm_gemm( dim3 threads(BLOCKWIDTH); vllm::squeezellm::NUQ4MatMulKernel<<>>( +#ifndef USE_ROCM (half2*) vec.data(), +#else + (__half2*) vec.data_ptr(), +#endif mat.data_ptr(), +#ifndef USE_ROCM (half2*) mul.data(), (__half*) lookup_table.data(), +#else + (float2*) mul.data_ptr(), + (__half*) lookup_table.data_ptr(), +#endif height, width, batch, vec_height ); } #undef BLOCKWIDTH -#undef BLOCKHEIGHT4 +#undef BLOCKHEIGHT4 \ No newline at end of file diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index bc35aa0424b56..59ac565c585e8 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -16,14 +16,14 @@ * limitations under the License. */ #pragma once - +#include "cuda_compat.h" namespace vllm { template __inline__ __device__ T warpReduceSum(T val) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) - val += __shfl_xor_sync(0xffffffff, val, mask, 32); + val += VLLM_SHFL_XOR_SYNC(val, mask); return val; } diff --git a/setup.py b/setup.py index 36f4913435628..0e28b93602772 100644 --- a/setup.py +++ b/setup.py @@ -8,20 +8,31 @@ from packaging.version import parse, Version import setuptools import torch -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME ROOT_DIR = os.path.dirname(__file__) MAIN_CUDA_VERSION = "12.1" # Supported NVIDIA GPU architectures. -SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"} +NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"} +ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030","gfx1100"} +SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS) # Compiler flags. CXX_FLAGS = ["-g", "-O2", "-std=c++17"] # TODO(woosuk): Should we use -O3? NVCC_FLAGS = ["-O2", "-std=c++17"] +if torch.version.hip: + if ROCM_HOME is not None: + NVCC_FLAGS += [f"-DUSE_ROCM"] + +if not torch.version.hip: + if CUDA_HOME is None: + raise RuntimeError( + "Cannot find CUDA_HOME. CUDA must be available to build the package.") + ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] @@ -61,20 +72,20 @@ def get_torch_arch_list() -> Set[str]: return set() # Filter out the invalid architectures and print a warning. - valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS}) + valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in NVIDIA_SUPPORTED_ARCHS}) arch_list = torch_arch_list.intersection(valid_archs) # If none of the specified architectures are valid, raise an error. if not arch_list: raise RuntimeError( - "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env " + "None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env " f"variable ({env_arch_list}) is supported. 
" - f"Supported CUDA architectures are: {valid_archs}.") + f"Supported CUDA/ROCM architectures are: {valid_archs}.") invalid_arch_list = torch_arch_list - valid_archs if invalid_arch_list: warnings.warn( - f"Unsupported CUDA architectures ({invalid_arch_list}) are " + f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are " "excluded from the `TORCH_CUDA_ARCH_LIST` env variable " - f"({env_arch_list}). Supported CUDA architectures are: " + f"({env_arch_list}). Supported CUDA/ROCM architectures are: " f"{valid_archs}.", stacklevel=2) return arch_list @@ -82,64 +93,65 @@ def get_torch_arch_list() -> Set[str]: # First, check the TORCH_CUDA_ARCH_LIST environment variable. compute_capabilities = get_torch_arch_list() -if not compute_capabilities: - # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available - # GPUs on the current machine. - device_count = torch.cuda.device_count() - for i in range(device_count): - major, minor = torch.cuda.get_device_capability(i) - if major < 7: - raise RuntimeError( - "GPUs with compute capability below 7.0 are not supported.") - compute_capabilities.add(f"{major}.{minor}") - -nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) -if not compute_capabilities: - # If no GPU is specified nor available, add all supported architectures - # based on the NVCC CUDA version. - compute_capabilities = SUPPORTED_ARCHS.copy() - if nvcc_cuda_version < Version("11.1"): - compute_capabilities.remove("8.6") - if nvcc_cuda_version < Version("11.8"): - compute_capabilities.remove("8.9") - compute_capabilities.remove("9.0") - -# Validate the NVCC CUDA version. -if nvcc_cuda_version < Version("11.0"): - raise RuntimeError("CUDA 11.0 or higher is required to build the package.") -if (nvcc_cuda_version < Version("11.1") - and any(cc.startswith("8.6") for cc in compute_capabilities)): - raise RuntimeError( - "CUDA 11.1 or higher is required for compute capability 8.6.") -if nvcc_cuda_version < Version("11.8"): - if any(cc.startswith("8.9") for cc in compute_capabilities): - # CUDA 11.8 is required to generate the code targeting compute capability 8.9. - # However, GPUs with compute capability 8.9 can also run the code generated by - # the previous versions of CUDA 11 and targeting compute capability 8.0. - # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0 - # instead of 8.9. - warnings.warn( - "CUDA 11.8 or higher is required for compute capability 8.9. " - "Targeting compute capability 8.0 instead.", - stacklevel=2) - compute_capabilities = set(cc for cc in compute_capabilities - if not cc.startswith("8.9")) - compute_capabilities.add("8.0+PTX") - if any(cc.startswith("9.0") for cc in compute_capabilities): +if not torch.version.hip: + if not compute_capabilities: + # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available + # GPUs on the current machine. + device_count = torch.cuda.device_count() + for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 7: + raise RuntimeError( + "GPUs with compute capability below 7.0 are not supported.") + compute_capabilities.add(f"{major}.{minor}") + +if not torch.version.hip: + nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) + if not compute_capabilities: + # If no GPU is specified nor available, add all supported architectures + # based on the NVCC CUDA version. 
+ compute_capabilities = SUPPORTED_ARCHS.copy() + if nvcc_cuda_version < Version("11.1"): + compute_capabilities.remove("8.6") + if nvcc_cuda_version < Version("11.8"): + compute_capabilities.remove("8.9") + compute_capabilities.remove("9.0") + # Validate the NVCC CUDA version. + if nvcc_cuda_version < Version("11.0"): + raise RuntimeError("CUDA 11.0 or higher is required to build the package.") + if (nvcc_cuda_version < Version("11.1") + and any(cc.startswith("8.6") for cc in compute_capabilities)): raise RuntimeError( - "CUDA 11.8 or higher is required for compute capability 9.0.") + "CUDA 11.1 or higher is required for compute capability 8.6.") + if nvcc_cuda_version < Version("11.8"): + if any(cc.startswith("8.9") for cc in compute_capabilities): + # CUDA 11.8 is required to generate the code targeting compute capability 8.9. + # However, GPUs with compute capability 8.9 can also run the code generated by + # the previous versions of CUDA 11 and targeting compute capability 8.0. + # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0 + # instead of 8.9. + warnings.warn( + "CUDA 11.8 or higher is required for compute capability 8.9. " + "Targeting compute capability 8.0 instead.", + stacklevel=2) + compute_capabilities = set(cc for cc in compute_capabilities + if not cc.startswith("8.9")) + compute_capabilities.add("8.0+PTX") + if any(cc.startswith("9.0") for cc in compute_capabilities): + raise RuntimeError( + "CUDA 11.8 or higher is required for compute capability 9.0.") -# Add target compute capabilities to NVCC flags. -for capability in compute_capabilities: - num = capability[0] + capability[2] - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] - if capability.endswith("+PTX"): - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] + # Add target compute capabilities to NVCC flags. + for capability in compute_capabilities: + num = capability[0] + capability[2] + NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] + if capability.endswith("+PTX"): + NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] -# Use NVCC threads to parallelize the build. -if nvcc_cuda_version >= Version("11.2"): - num_threads = min(os.cpu_count(), 8) - NVCC_FLAGS += ["--threads", str(num_threads)] + # Use NVCC threads to parallelize the build. + if nvcc_cuda_version >= Version("11.2"): + num_threads = min(os.cpu_count(), 8) + NVCC_FLAGS += ["--threads", str(num_threads)] ext_modules = [] @@ -199,18 +211,31 @@ def get_torch_arch_list() -> Set[str]: ext_modules.append(activation_extension) # Quantization kernels. -quantization_extension = CUDAExtension( - name="vllm.quantization_ops", - sources=[ - "csrc/quantization.cpp", - "csrc/quantization/awq/gemm_kernels.cu", - "csrc/quantization/squeezellm/quant_cuda_kernel.cu", - ], - extra_compile_args={ - "cxx": CXX_FLAGS, - "nvcc": NVCC_FLAGS, - }, -) +if not torch.version.hip: + quantization_extension = CUDAExtension( + name="vllm.quantization_ops", + sources=[ + "csrc/quantization.cpp", + "csrc/quantization/awq/gemm_kernels.cu", + "csrc/quantization/squeezellm/quant_cuda_kernel.cu", + ], + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS, + }, + ) +else: + quantization_extension = CUDAExtension( + name="vllm.quantization_ops", + sources=[ + "csrc/quantization.cpp", + "csrc/quantization/squeezellm/quant_cuda_kernel.cu", + ], + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS, + }, + ) ext_modules.append(quantization_extension) # Misc. CUDA utils. 
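# Background for the ROCm branch above: with a ROCm build of PyTorch,
# torch.utils.cpp_extension hipifies the .cu sources before compiling them, and the
# generated *.hip / *_hip* files are what the new .gitignore entries at the top of this
# patch exclude. The -DUSE_ROCM define appended to NVCC_FLAGS is the switch that
# csrc/cuda_compat.h and the kernel sources branch on.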
@@ -244,10 +269,10 @@ def find_version(filepath: str) -> str: def get_vllm_version() -> str: version = find_version(get_path("vllm", "__init__.py")) - cuda_version = str(nvcc_cuda_version) - if cuda_version != MAIN_CUDA_VERSION: - cuda_version_str = cuda_version.replace(".", "")[:3] - version += f"+cu{cuda_version_str}" + # cuda_version = str(nvcc_cuda_version) + # if cuda_version != MAIN_CUDA_VERSION: + # cuda_version_str = cuda_version.replace(".", "")[:3] + # version += f"+cu{cuda_version_str}" return version From 9a1781c7ee4cbb56dce5e8415e4b0958bf0df3de Mon Sep 17 00:00:00 2001 From: TJian Date: Wed, 29 Nov 2023 21:40:31 +0800 Subject: [PATCH 2/2] Checked CUDA ROCm Compatibility (#15) Features * Auto-code path selection * support llama2 * support squeezellm rocm * add documentation amd-installation.rst. Describing how to setup vllm ROCm version * format.sh all the code * add base amd.Dockerfile --------- Co-authored-by: tjtanaa Co-authored-by: kuanfu --- csrc/quantization.cpp | 17 ++- .../getting_started/amd-installation.rst | 104 +++++++++++++ patch_xformers-0.0.22.post7.rocm.sh | 22 +++ pyproject.toml | 2 +- requirements-rocm.txt | 16 ++ rocm.Dockerfile | 64 ++++++++ .../commonpy_xformers-0.0.22.post7.rocm.patch | 13 ++ .../flashpy_xformers-0.0.22.post7.rocm.patch | 134 +++++++++++++++++ setup.py | 138 +++++++++++++----- vllm/engine/arg_utils.py | 98 ++++++++----- vllm/engine/ray_utils.py | 4 +- vllm/model_executor/layers/attention.py | 4 + .../layers/quantization/__init__.py | 8 +- .../model_executor/layers/quantization/awq.py | 6 +- .../layers/quantization/squeezellm.py | 17 ++- 15 files changed, 560 insertions(+), 87 deletions(-) create mode 100644 docs/source/getting_started/amd-installation.rst create mode 100644 patch_xformers-0.0.22.post7.rocm.sh create mode 100644 requirements-rocm.txt create mode 100644 rocm.Dockerfile create mode 100644 rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch create mode 100644 rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch diff --git a/csrc/quantization.cpp b/csrc/quantization.cpp index dfe17a496c780..764188597fd30 100644 --- a/csrc/quantization.cpp +++ b/csrc/quantization.cpp @@ -1,11 +1,14 @@ #include -torch::Tensor awq_gemm( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters); + +#ifndef USE_ROCM + torch::Tensor awq_gemm( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int split_k_iters); +#endif void squeezellm_gemm( torch::Tensor vec, @@ -14,6 +17,8 @@ void squeezellm_gemm( torch::Tensor lookup_table); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +#ifndef USE_ROCM m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); +#endif m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); } diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst new file mode 100644 index 0000000000000..862b8a1d6e89b --- /dev/null +++ b/docs/source/getting_started/amd-installation.rst @@ -0,0 +1,104 @@ +.. _installation: + +Installation with ROCm +============ + +vLLM-ROCm is here! Currently it is supporting llama-2. + +Requirements +------------ + +* OS: Linux +* Python: 3.8 -- 3.11 (Recommended 3.10 as this is the version that has been tested on.) +* GPU: MI210 +* Pytorch 2.0.1/2.1.1 +* ROCm 5.7 + + +Install with pip +---------------- + +You can install vLLM using pip: + +.. 
code-block:: console + + $ # (Optional) Create a new conda environment. + $ conda create -n myenv python=3.8 -y + $ conda activate myenv + + $ # Install vLLM with CUDA 12.1. + $ pip install vllm + +.. note:: + + As of now, vLLM's binaries are compiled on CUDA 12.1 by default. + However, you can install vLLM with CUDA 11.8 by running: + + .. code-block:: console + + $ # Install vLLM with CUDA 11.8. + $ # Replace `cp310` with your Python version (e.g., `cp38`, `cp39`, `cp311`). + $ pip install https://github.com/vllm-project/vllm/releases/download/v0.2.2/vllm-0.2.2+cu118-cp310-cp310-manylinux1_x86_64.whl + + $ # Re-install PyTorch with CUDA 11.8. + $ pip uninstall torch -y + $ pip install torch --upgrade --index-url https://download.pytorch.org/whl/cu118 + + +.. _build_from_source: + +Build from source with docker +----------------- + +You can also build and install vLLM from source: + +Build a docker image from `rocm.Dockerfile`, and launch a docker container. + +.. code-block:: console + + $ docker build -f rocm.Dockerfile -t vllm-rocm . + $ docker run -it \ + --network=host \ + --group-add=video \ + --ipc=host \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + --shm-size 8G \ + --device /dev/kfd \ + --device /dev/dri \ + -v :/app/hf_model \ + vllm-rocm \ + bash + +If you are going to setup on new pytorch+rocm5.7 docker container, you can follow the following steps. + +1. Install flash-attention-2-rocm + + If you are using Pytorch-2.0.1+rocm5.7. + + Install flash-attention-2 (v2.0.4) following the instruction from [ROCmSoftwarePlatform/flash-attention](https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm) + + + If you are using Pytorch-2.1.x+rocm5.7 or Pytorch-2.2.x+rocm5.7, you don't need to apply the `hipify_python.patch`. + You can directly build the flash-attention-2. + + .. code-block:: console + + $ bash patch_torch211_flash_attn2.rocm.sh + + .. note:: + - Flash-attention-2 (v2.0.4) does not support sliding windows attention. + - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) + +2. Setup xformers==0.0.22.post7 without dependencies, and apply patches + + .. code-block:: console + + $ pip install xformers==0.0.22.post7 --no-deps + $ bash patch_xformers-0.0.22.post7.rocm.sh + +3. Build vllm. + + .. code-block:: console + $ cd vllm + $ python setup.py install # This may take 5-10 minutes. diff --git a/patch_xformers-0.0.22.post7.rocm.sh b/patch_xformers-0.0.22.post7.rocm.sh new file mode 100644 index 0000000000000..c8e58f721ae8c --- /dev/null +++ b/patch_xformers-0.0.22.post7.rocm.sh @@ -0,0 +1,22 @@ +#!/bin/bash +export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)') +export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)') + +echo $XFORMERS_FMHA_FLASH_PATH +echo $XFORMERS_FMHA_COMMON_PATH + +if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"; then + echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}" + patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch" + echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}" +else + echo "${XFORMERS_FMHA_FLASH_PATH} was patched before" +fi + +if ! 
patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"; then + echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}" + patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch" + echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}" +else + echo "${XFORMERS_FMHA_COMMON_PATH} was patched before" +fi \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index e3e3e389f7897..f9390ed8c52a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = [ "ninja", "packaging", "setuptools", - "torch >= 2.1.0", + # "torch >= 2.1.0", # commented out to accommodate ROCm "wheel", ] build-backend = "setuptools.build_meta" diff --git a/requirements-rocm.txt b/requirements-rocm.txt new file mode 100644 index 0000000000000..53fd3ea24d927 --- /dev/null +++ b/requirements-rocm.txt @@ -0,0 +1,16 @@ +ninja # For faster builds. +typing-extensions>=4.8.0 +starlette +psutil +ray >= 2.5.1 +pandas # Required for Ray data. +pyarrow # Required for Ray data. +sentencepiece # Required for LLaMA tokenizer. +numpy +tokenizers>=0.15.0 +huggingface_hub<0.18,>=0.16.4 +einops # Required for phi-1_5 +transformers >= 4.34.0 # Required for Mistral. +fastapi +uvicorn[standard] +pydantic == 1.10.13 # Required for OpenAI server. diff --git a/rocm.Dockerfile b/rocm.Dockerfile new file mode 100644 index 0000000000000..27b57097740d2 --- /dev/null +++ b/rocm.Dockerfile @@ -0,0 +1,64 @@ +FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1 + +# Install some basic utilities +RUN apt-get update && apt-get install python3 python3-pip -y + +# Install some basic utilities +RUN apt-get update && apt-get install -y \ + curl \ + ca-certificates \ + sudo \ + git \ + bzip2 \ + libx11-6 \ + build-essential \ + wget \ + unzip \ + nvidia-cuda-toolkit \ + tmux \ + && rm -rf /var/lib/apt/lists/* + +### Mount Point ### +# When launching the container, mount the code directory to /app +ARG APP_MOUNT=/app +VOLUME [ ${APP_MOUNT} ] +WORKDIR ${APP_MOUNT} + +RUN python3 -m pip install --upgrade pip +RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers + +ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer +ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin: +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib: +ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/: +ENV PYTORCH_ROCM_ARCH=gfx900;gfx906;gfx908;gfx90a;gfx1030;gfx1101 + +# Install ROCm flash-attention +RUN mkdir libs \ + && cd libs \ + && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \ + && cd flash-attention \ + && git submodule update --init \ + && sed -i -e "s/--offload-arch=native/--offload-arch=$(/opt/rocm/llvm/bin/amdgpu-offload-arch)/g" setup.py \ + && patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \ + && python3 setup.py install \ + && cd .. + +COPY ./ /app/vllm-rocm/ + +# RUN cd /app \ +# && cd vllm-rocm \ +# && git checkout v0.2.1.post1-rocm \ +# && python3 setup.py install \ +# && cd .. + +# RUN cd /app \ +# && mkdir dataset \ +# && cd .. 
+ +# COPY ./benchmark_throughput.sh /app/benchmark_throughput.sh + +RUN python3 -m pip install --upgrade pip +# RUN python3 -m pip install --no-cache-dir ray[all] + +CMD ["/bin/bash"] \ No newline at end of file diff --git a/rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch b/rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch new file mode 100644 index 0000000000000..4d7495cf13e1d --- /dev/null +++ b/rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch @@ -0,0 +1,13 @@ +--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/common.py 2023-11-29 03:17:03.930103539 +0000 ++++ common.py 2023-11-28 16:14:19.846233146 +0000 +@@ -298,8 +298,8 @@ + dtype = d.query.dtype + if device_type not in cls.SUPPORTED_DEVICES: + reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})") +- if device_type == "cuda" and not _built_with_cuda: +- reasons.append("xFormers wasn't build with CUDA support") ++ #if device_type == "cuda" and not _built_with_cuda: ++ # reasons.append("xFormers wasn't build with CUDA support") + if device_type == "cuda": + device_capability = torch.cuda.get_device_capability(d.device) + if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY: diff --git a/rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch b/rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch new file mode 100644 index 0000000000000..4798f1efd461c --- /dev/null +++ b/rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch @@ -0,0 +1,134 @@ +--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py 2023-11-29 03:17:03.930103539 +0000 ++++ flash.py 2023-11-28 16:14:25.206128903 +0000 +@@ -31,39 +31,39 @@ + + FLASH_VERSION = "0.0.0" + try: +- try: +- from ... import _C_flashattention # type: ignore[attr-defined] +- from ..._cpp_lib import _build_metadata +- +- if _build_metadata is not None: +- FLASH_VERSION = _build_metadata.flash_version +- except ImportError: +- import flash_attn +- from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention +- +- FLASH_VERSION = flash_attn.__version__ +- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2]) +- if flash_ver_parsed < (2, 3): +- raise ImportError("Requires 2.3 for sliding window support") ++ #try: ++ # from ... import _C_flashattention # type: ignore[attr-defined] ++ # from ..._cpp_lib import _build_metadata ++ ++ # if _build_metadata is not None: ++ # FLASH_VERSION = _build_metadata.flash_version ++ #except ImportError: ++ import flash_attn ++ from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention ++ ++ FLASH_VERSION = flash_attn.__version__ ++ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2]) ++ # if flash_ver_parsed < (2, 3): ++ # raise ImportError("Requires 2.3 for sliding window support") + + # create library so that flash-attn goes through the PyTorch Dispatcher +- _flash_lib = torch.library.Library("xformers_flash", "DEF") ++ #_flash_lib = torch.library.Library("xformers_flash", "DEF") + +- _flash_lib.define( +- "flash_fwd(Tensor query, Tensor key, Tensor value, " +- "Tensor? cu_seqlens_q, Tensor? 
cu_seqlens_k, " +- "int max_seqlen_q, int max_seqlen_k, " +- "float p, float softmax_scale, " +- "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)" +- ) +- +- _flash_lib.define( +- "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, " +- "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, " +- "Tensor cu_seqlens_q, Tensor cu_seqlens_k, " +- "int max_seqlen_q, int max_seqlen_k, " +- "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)" +- ) ++ #_flash_lib.define( ++ # "flash_fwd(Tensor query, Tensor key, Tensor value, " ++ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, " ++ # "int max_seqlen_q, int max_seqlen_k, " ++ # "float p, float softmax_scale, " ++ # "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)" ++ #) ++ ++ #_flash_lib.define( ++ # "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, " ++ # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, " ++ # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, " ++ # "int max_seqlen_q, int max_seqlen_k, " ++ # "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)" ++ #) + + def _flash_fwd( + query, +@@ -98,8 +98,8 @@ + p, + softmax_scale, + is_causal, +- window_size - 1, # window_size_left +- -1, # window_size_right ++ # window_size - 1, # window_size_left ++ # -1, # window_size_right + return_softmax, + None, # rng + ) +@@ -127,8 +127,8 @@ + softmax_scale, + False, + is_causal, +- window_size - 1, # window_size_left +- -1, # window_size_right ++ # window_size - 1, # window_size_left ++ # -1, # window_size_right + return_softmax, + None, + ) +@@ -169,8 +169,8 @@ + p, + softmax_scale, + is_causal, +- window_size - 1, # window_size_left +- -1, # window_size_right ++ # window_size - 1, # window_size_left ++ # -1, # window_size_right + None, + rng_state, + ) +@@ -193,15 +193,15 @@ + softmax_scale, + False, # zero_tensors + is_causal, +- window_size - 1, # window_size_left +- -1, # window_size_right ++ # window_size - 1, # window_size_left ++ # -1, # window_size_right + None, + rng_state, + ) + return dq, dk, dv + +- _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA") +- _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA") ++ #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA") ++ #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA") + except ImportError: + pass + +@@ -348,7 +348,7 @@ + implementation. + """ + +- OPERATOR = get_operator("xformers_flash", "flash_fwd") ++ OPERATOR = _flash_fwd # get_operator("xformers_flash", "flash_fwd") + SUPPORTED_DEVICES: Set[str] = {"cuda"} + CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) + SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16} diff --git a/setup.py b/setup.py index 0e28b93602772..b75d0912a9bb6 100644 --- a/setup.py +++ b/setup.py @@ -16,30 +16,64 @@ # Supported NVIDIA GPU architectures. NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"} -ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030","gfx1100"} -SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS) +ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"} +# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS) # Compiler flags. CXX_FLAGS = ["-g", "-O2", "-std=c++17"] # TODO(woosuk): Should we use -O3? 
NVCC_FLAGS = ["-O2", "-std=c++17"] -if torch.version.hip: - if ROCM_HOME is not None: - NVCC_FLAGS += [f"-DUSE_ROCM"] +if torch.cuda.is_available() and torch.version.hip and ROCM_HOME is not None: + NVCC_FLAGS += ["-DUSE_ROCM"] -if not torch.version.hip: - if CUDA_HOME is None: - raise RuntimeError( - "Cannot find CUDA_HOME. CUDA must be available to build the package.") +if torch.cuda.is_available() and torch.version.cuda and CUDA_HOME is None: + raise RuntimeError( + "Cannot find CUDA_HOME. CUDA must be available to build the package.") ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] -if CUDA_HOME is None: - raise RuntimeError( - "Cannot find CUDA_HOME. CUDA must be available to build the package.") + +def get_amdgpu_offload_arch(): + error_message = "" + command = "/opt/rocm/llvm/bin/amdgpu-offload-arch" + try: + output = subprocess.check_output([command]) + return output.decode('utf-8').strip() + except subprocess.CalledProcessError as e: + error_message = f"Error: {e}" + except FileNotFoundError: + # If the command is not found, print an error message + error_message = f"The command {command} was not found." + + if error_message: + raise RuntimeError(error_message) + + return None + + +def get_hipcc_rocm_version(): + # Run the hipcc --version command + result = subprocess.run(['hipcc', '--version'], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True) + + # Check if the command was executed successfully + if result.returncode != 0: + print("Error running 'hipcc --version'") + return None + + # Extract the version using a regular expression + match = re.search(r'HIP version: (\S+)', result.stdout) + if match: + # Return the version string + return match.group(1) + else: + print("Could not find HIP version in the output") + return None def get_nvcc_cuda_version(cuda_dir: str) -> Version: @@ -72,7 +106,9 @@ def get_torch_arch_list() -> Set[str]: return set() # Filter out the invalid architectures and print a warning. - valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in NVIDIA_SUPPORTED_ARCHS}) + valid_archs = NVIDIA_SUPPORTED_ARCHS.union( + {s + "+PTX" + for s in NVIDIA_SUPPORTED_ARCHS}) arch_list = torch_arch_list.intersection(valid_archs) # If none of the specified architectures are valid, raise an error. if not arch_list: @@ -93,24 +129,24 @@ def get_torch_arch_list() -> Set[str]: # First, check the TORCH_CUDA_ARCH_LIST environment variable. compute_capabilities = get_torch_arch_list() -if not torch.version.hip: - if not compute_capabilities: - # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available - # GPUs on the current machine. - device_count = torch.cuda.device_count() - for i in range(device_count): - major, minor = torch.cuda.get_device_capability(i) - if major < 7: - raise RuntimeError( - "GPUs with compute capability below 7.0 are not supported.") - compute_capabilities.add(f"{major}.{minor}") - -if not torch.version.hip: +if torch.cuda.is_available( +) and torch.version.cuda and not compute_capabilities: + # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available + # GPUs on the current machine. 
+ device_count = torch.cuda.device_count() + for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 7: + raise RuntimeError( + "GPUs with compute capability below 7.0 are not supported.") + compute_capabilities.add(f"{major}.{minor}") + +if torch.cuda.is_available() and torch.version.cuda: nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) if not compute_capabilities: # If no GPU is specified nor available, add all supported architectures # based on the NVCC CUDA version. - compute_capabilities = SUPPORTED_ARCHS.copy() + compute_capabilities = NVIDIA_SUPPORTED_ARCHS.copy() if nvcc_cuda_version < Version("11.1"): compute_capabilities.remove("8.6") if nvcc_cuda_version < Version("11.8"): @@ -118,7 +154,8 @@ def get_torch_arch_list() -> Set[str]: compute_capabilities.remove("9.0") # Validate the NVCC CUDA version. if nvcc_cuda_version < Version("11.0"): - raise RuntimeError("CUDA 11.0 or higher is required to build the package.") + raise RuntimeError( + "CUDA 11.0 or higher is required to build the package.") if (nvcc_cuda_version < Version("11.1") and any(cc.startswith("8.6") for cc in compute_capabilities)): raise RuntimeError( @@ -135,7 +172,7 @@ def get_torch_arch_list() -> Set[str]: "Targeting compute capability 8.0 instead.", stacklevel=2) compute_capabilities = set(cc for cc in compute_capabilities - if not cc.startswith("8.9")) + if not cc.startswith("8.9")) compute_capabilities.add("8.0+PTX") if any(cc.startswith("9.0") for cc in compute_capabilities): raise RuntimeError( @@ -146,13 +183,22 @@ def get_torch_arch_list() -> Set[str]: num = capability[0] + capability[2] NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] if capability.endswith("+PTX"): - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] + NVCC_FLAGS += [ + "-gencode", f"arch=compute_{num},code=compute_{num}" + ] # Use NVCC threads to parallelize the build. if nvcc_cuda_version >= Version("11.2"): num_threads = min(os.cpu_count(), 8) NVCC_FLAGS += ["--threads", str(num_threads)] +elif torch.cuda.is_available() and torch.version.hip: + amd_arch = get_amdgpu_offload_arch() + if amd_arch not in ROCM_SUPPORTED_ARCHS: + raise RuntimeError( + f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}" + f"amdgpu_arch_found: {amd_arch}") + ext_modules = [] # Cache operations. @@ -211,7 +257,7 @@ def get_torch_arch_list() -> Set[str]: ext_modules.append(activation_extension) # Quantization kernels. 
-if not torch.version.hip: +if torch.cuda.is_available() and torch.version.cuda: quantization_extension = CUDAExtension( name="vllm.quantization_ops", sources=[ @@ -224,7 +270,7 @@ def get_torch_arch_list() -> Set[str]: "nvcc": NVCC_FLAGS, }, ) -else: +elif torch.cuda.is_available() and torch.version.hip: quantization_extension = CUDAExtension( name="vllm.quantization_ops", sources=[ @@ -269,10 +315,20 @@ def find_version(filepath: str) -> str: def get_vllm_version() -> str: version = find_version(get_path("vllm", "__init__.py")) - # cuda_version = str(nvcc_cuda_version) - # if cuda_version != MAIN_CUDA_VERSION: - # cuda_version_str = cuda_version.replace(".", "")[:3] - # version += f"+cu{cuda_version_str}" + + if torch.cuda.is_available() and torch.version.cuda: + cuda_version = str(nvcc_cuda_version) + if cuda_version != MAIN_CUDA_VERSION: + cuda_version_str = cuda_version.replace(".", "")[:3] + version += f"+cu{cuda_version_str}" + + elif torch.cuda.is_available() and torch.version.hip: + # Get the HIP version + hipcc_version = get_hipcc_rocm_version() + if hipcc_version != MAIN_CUDA_VERSION: + rocm_version_str = hipcc_version.replace(".", "")[:3] + version += f"+rocm{rocm_version_str}" + return version @@ -287,8 +343,14 @@ def read_readme() -> str: def get_requirements() -> List[str]: """Get Python package dependencies from requirements.txt.""" - with open(get_path("requirements.txt")) as f: - requirements = f.read().strip().split("\n") + if torch.cuda.is_available() and torch.version.hip: + with open(get_path("requirements-rocm.txt")) as f: + requirements = f.read().strip().split("\n") + elif torch.cuda.is_available() and torch.version.cuda: + with open(get_path("requirements.txt")) as f: + requirements = f.read().strip().split("\n") + print("requirements: ", requirements) + # exit() return requirements diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c7e476c704740..c7612b3ac4072 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -3,6 +3,8 @@ from dataclasses import dataclass from typing import Optional, Tuple +import torch + from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) @@ -83,32 +85,52 @@ def add_cli_args( help='directory to download and load the weights, ' 'default to the default cache dir of ' 'huggingface') - parser.add_argument( - '--load-format', - type=str, - default=EngineArgs.load_format, - choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], - help='The format of the model weights to load. ' - '"auto" will try to load the weights in the safetensors format ' - 'and fall back to the pytorch bin format if safetensors format ' - 'is not available. ' - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading. ' - '"dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.') - parser.add_argument( - '--dtype', - type=str, - default=EngineArgs.dtype, - choices=[ - 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' - ], - help='data type for model weights and activations. 
' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') + if torch.cuda.is_available() and torch.version.hip: + # do something specific for HIP + parser.add_argument( + '--load-format', + type=str, + default='pt', + choices=['pt'], + help='The format of the model weights to load. ' + '"pt" will load the weights in the pytorch bin format. ') + parser.add_argument( + '--dtype', + type=str, + default='half', + choices=['half', 'float16', 'bfloat16'], + help='data type for model weights and activations. ' + 'The default option is FP16 precision ' + 'Supports FP16 and BF16 ') + elif torch.cuda.is_available() and torch.version.cuda: + # do something specific for CUDA + parser.add_argument( + '--load-format', + type=str, + default=EngineArgs.load_format, + choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], + help='The format of the model weights to load. ' + '"auto" will try to load the weights in the safetensors format ' + 'and fall back to the pytorch bin format if safetensors format ' + 'is not available. ' + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + 'a numpy cache to speed up the loading. ' + '"dummy" will initialize the weights with random values, ' + 'which is mainly for profiling.') + parser.add_argument( + '--dtype', + type=str, + default=EngineArgs.dtype, + choices=[ + 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' + ], + help='data type for model weights and activations. ' + 'The "auto" option will use FP16 precision ' + 'for FP32 and FP16 models, and BF16 precision ' + 'for BF16 models.') + parser.add_argument('--max-model-len', type=int, default=None, @@ -171,13 +193,23 @@ def add_cli_args( parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics') - # Quantization settings. - parser.add_argument('--quantization', - '-q', - type=str, - choices=['awq', 'squeezellm', None], - default=None, - help='Method used to quantize the weights') + if torch.cuda.is_available() and torch.version.hip: + # Quantization settings. + parser.add_argument('--quantization', + '-q', + type=str, + choices=['squeezellm', None], + default=None, + help='Method used to quantize the weights') + + elif torch.cuda.is_available() and torch.version.cuda: + # Quantization settings. + parser.add_argument('--quantization', + '-q', + type=str, + choices=['awq', 'squeezellm', None], + default=None, + help='Method used to quantize the weights') return parser @classmethod diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index ee58b8b9074a7..6bff8153e2a5c 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -73,7 +73,9 @@ def initialize_cluster( "Ray is not installed. Please install Ray to use distributed " "serving.") # Connect to a ray cluster. - ray.init(address=ray_address, ignore_reinit_error=True) + ray.init(address=ray_address, + ignore_reinit_error=True, + num_gpus=parallel_config.world_size) if not parallel_config.worker_use_ray: # Initialize cluster locally. 
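# Note on the attention change below: on HIP the memory-efficient attention call pins
# op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] (the forward operator) so
# that xformers dispatches directly to the ROCm flash-attention kernels wired up by the
# flashpy_xformers patch above, rather than running its backend auto-selection; on CUDA
# the op stays None and the previous behavior is unchanged.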
diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index e51bb311decd9..2e042721d9a27 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -113,6 +113,8 @@ def multi_query_kv_attention( attn_bias=input_metadata.attn_bias, p=0.0, scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (torch.cuda.is_available() and torch.version.hip) else None, ) # TODO(woosuk): Unnecessary copy. Optimize. output.copy_(out.view_as(output)) @@ -451,6 +453,8 @@ def multi_query_kv_attention( attn_bias=input_metadata.attn_bias, p=0.0, scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (torch.cuda.is_available() and torch.version.hip) else None, ) # TODO(woosuk): Unnecessary copy. Optimize. output.copy_(out.view_as(output)) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 3d937ba64f9fa..f4d25566cf599 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,14 +1,16 @@ from typing import Type - -from vllm.model_executor.layers.quantization.awq import AWQConfig +import torch from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig _QUANTIZATION_CONFIG_REGISTRY = { - "awq": AWQConfig, "squeezellm": SqueezeLLMConfig, } +if torch.cuda.is_available() and torch.version.cuda: + from vllm.model_executor.layers.quantization.awq import AWQConfig + _QUANTIZATION_CONFIG_REGISTRY["awq"] = AWQConfig + def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: if quantization not in _QUANTIZATION_CONFIG_REGISTRY: diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 0ab5819d930aa..5561a9309fc32 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -2,8 +2,12 @@ import torch from torch.nn.parameter import Parameter +if torch.cuda.is_available() and torch.version.hip: + # do something specific for HIP + print("Warning: vLLM does not support AWQ on ROCm.") +elif torch.cuda.is_available() and torch.version.cuda: + from vllm import quantization_ops -from vllm import quantization_ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 61ec8b79b6ddc..be318b2ef205b 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -114,10 +114,19 @@ def apply_weights(self, lookup_table = weights["lookup_table"] out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) - # NOTE: The output tensor should be zero-initialized. 
- out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) - quantization_ops.squeezellm_gemm(reshaped_x, qweight, out, - lookup_table) + if torch.cuda.is_available() and torch.version.hip: + out_float = torch.zeros(out_shape, + device="cuda", + dtype=torch.float) + quantization_ops.squeezellm_gemm(reshaped_x, qweight, out_float, + lookup_table) + out = out_float.to(dtype=torch.float16) + # do something specific for HIP + elif torch.cuda.is_available() and torch.version.cuda: + # NOTE: The output tensor should be zero-initialized. + out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) + quantization_ops.squeezellm_gemm(reshaped_x, qweight, out, + lookup_table) if bias is not None: out = out + bias