diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 46d5e97b176e3..eb7c278435ab9 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -1,4 +1,19 @@ -// TODO: add license terms +/* + * Copyright (c) 2024, The vLLM team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #include #include #include @@ -594,7 +609,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } } else { // warp in context - // iterate across heads #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { diff --git a/csrc/rocm/custom copy.cu b/csrc/rocm/custom copy.cu deleted file mode 100644 index fae1b4fbfbe33..0000000000000 --- a/csrc/rocm/custom copy.cu +++ /dev/null @@ -1,78 +0,0 @@ -#include -#include -#include - -// declare templates for front (cpp) and back (cuda) sides of function: -// template - -void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int K, - cudaStream_t stream, const int rows_per_block); -void LLMM_Silu(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - const int64_t rows_per_block) { - auto M = in_a.size(0); - auto K = in_a.size(1); - LLGemm_Silu(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, - at::cuda::getCurrentCUDAStream(), rows_per_block); -} - -void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K, - cudaStream_t stream, const int rows_per_block); - -// template -void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - const int64_t rows_per_block) { - auto M = in_a.size(0); - auto K = in_a.size(1); - // if (N != in_b.numel()) - // throw std::invalid_argument("Size mismatch A.numel(): " + - // std::to_string(in_a.numel()) - // + ", B.numel(): " + - // std::to_string(in_b.numel())); - - // out_c.resize_({N}); - - // call the kernel function... - LLGemm1(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, - at::cuda::getCurrentCUDAStream(), rows_per_block); -} - -void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M, const int K, - const int N, cudaStream_t stream, const int CuCount); - -void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, - const int64_t N_in, const int64_t CuCount) { - auto M = in_a.size(0); - auto K = in_a.size(1); - int N = N_in; - wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, - at::cuda::getCurrentCUDAStream(), CuCount); -} - -void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, - cudaStream_t stream, const int solidx); - -void LLZZ(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, - const int64_t solidx = 0) { - auto M = in_a.size(0); - auto K = in_a.size(1); - - LLGemmZZ(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, - at::cuda::getCurrentCUDAStream(), solidx); -} -// instantiate the CPP template for T=float: -// template void AddGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor -// out_c); - -void MMGPUKernel(float* in_a, float* in_b, float* out_c, int numARows, - int numAColumns, int numBRows, int numBColumns, int numCRows, - int numCColumns, cudaStream_t stream); - -void MMCustomGPU(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c) { - auto matA_sizes{in_a.sizes()}; - auto matB_sizes{in_b.sizes()}; - auto matO_sizes{out_c.sizes()}; - MMGPUKernel(in_a.data_ptr(), in_b.data_ptr(), - out_c.data_ptr(), matA_sizes[0], matA_sizes[1], - matB_sizes[0], matB_sizes[1], matO_sizes[0], matO_sizes[1], - at::cuda::getCurrentCUDAStream()); -} diff --git a/csrc/rocm/custom_kernels copy.cu b/csrc/rocm/custom_kernels copy.cu deleted file mode 100644 index f7dba39bb55ad..0000000000000 --- a/csrc/rocm/custom_kernels copy.cu +++ /dev/null @@ -1,1309 +0,0 @@ -#include -#include -#include -#include -#include "cuda_compat.h" - -#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ - defined(__gfx941__) || defined(__gfx942__)) - #define __HIP__MI300_MI250__ -#endif - -#if defined(NDEBUG) - #undef NDEBUG - #include - #define UNREACHABLE_CODE assert(false); - #define NDEBUG -#else - #define UNREACHABLE_CODE assert(false); -#endif - -template -__device__ __forceinline__ T loadnt(T* addr) { - return __builtin_nontemporal_load(addr); -} - -__device__ __forceinline__ float4 load_ntmprl(const float4* addr) { - auto addr_alias = reinterpret_cast(addr); - auto dat0 = loadnt(addr_alias); - auto dat1 = loadnt(addr_alias + 1); - auto dat2 = loadnt(addr_alias + 2); - auto dat3 = loadnt(addr_alias + 3); - // auto dat0 = *(addr_alias); - // auto dat1 = *(addr_alias+1); - // auto dat2 = *(addr_alias+2); - // auto dat3 = *(addr_alias+3); - return make_float4(dat0, dat1, dat2, dat3); -} - -// TBlock fetches entire rows of A, and entire col of B (K dimension); assume -// N=1 for time being grid is M/A_NUM_ROWS blocks -template -__global__ void LLGemm1_kernel(float4* af4, __half2* bf4, __half2* c, - const int K) { - __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; - const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * K / 8; - const int threadid = threadIdx.x; - const int warp = threadIdx.x / WARP_SIZE; - const int lane = threadIdx.x % WARP_SIZE; - const int num_warps = blockDim.x / WARP_SIZE; - const int qwarpid = threadid / 16; - const int qthreadid = threadid % 16; - float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; - __half2 colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; - float4 sum4; //[NUM_A_ROWS_PER_BLOCK]; - float acc[NUM_A_ROWS_PER_BLOCK] = {0.0}; - __half2 acch2; - __half2 oval; - - // As we later use warp shuffle operations, we may have more threads in the - // block than the actual available data, hence the if guard here. - if (threadid * 8 < K) { -#pragma unroll - for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { - // rowA_elem4[i] holds 8 * half numbers seen as a single float4. - rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]); - } - } - - colB_elem4x = bf4[threadid * 4 + 0]; - colB_elem4y = bf4[threadid * 4 + 1]; - colB_elem4z = bf4[threadid * 4 + 2]; - colB_elem4w = bf4[threadid * 4 + 3]; - - __half2 Af2; - __half2 Bf2; - float2 S; - - auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4); - __half2* ah2lptr; - -#pragma unroll - for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { - // Multiply-add on 8 half. - ah2lptr = Ah2ptr + i * 4; - Af2 = *(ah2lptr); - acch2 = __hmul2(Af2, colB_elem4x); - Af2 = *(ah2lptr + 1); - acch2 = __hfma2(Af2, colB_elem4y, acch2); - Af2 = *(ah2lptr + 2); - acch2 = __hfma2(Af2, colB_elem4z, acch2); - Af2 = *(ah2lptr + 3); - acch2 = __hfma2(Af2, colB_elem4w, acch2); - S = __half22float2(acch2); - - // See comment above concerning the if guard. - if (threadid * 8 < K) { - acc[i] = S.x + S.y; // accumulation on float - } - } - -// all reduce across warp. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { -#pragma unroll - for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { - acc[i] += __shfl_xor(acc[i], mask); - } - } - - // Warp leaders store the data to shared memory. - if (lane < NUM_A_ROWS_PER_BLOCK) { - red_smem[lane][warp] = acc[lane]; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - if (qwarpid < NUM_A_ROWS_PER_BLOCK) { - acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f; -#pragma unroll - for (int mask = 16 / 2; mask >= 1; mask /= 2) { - acc[qwarpid] += __shfl_xor(acc[qwarpid], mask); - } - float oval2 = __shfl_xor(acc[qwarpid], 16); - - if (threadid % WARP_SIZE == 0 or threadid % WARP_SIZE == 32) { - oval = __float22half2_rn(make_float2(acc[qwarpid], oval2)); - c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval; - } - } -} - -// define the kernel calling code: -// template -void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K, - cudaStream_t stream, const int rows_per_block = 4) { - float4* af4 = reinterpret_cast(in_a); - auto* bf4 = reinterpret_cast<__half2*>(in_b); - auto* c = reinterpret_cast<__half2*>(out_c); - - // NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle - // operations. - const int NUM_THREADS = - K * 2 / 16 % WARP_SIZE == 0 - ? K * 2 / 16 - : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE); - - int NUM_BLOCKS = M / rows_per_block; - - if (rows_per_block == 2) { - LLGemm1_kernel<2><<>>(af4, bf4, c, K); - } else if (rows_per_block == 4) { - LLGemm1_kernel<4><<>>(af4, bf4, c, K); - } else if (rows_per_block == 8) { - LLGemm1_kernel<8><<>>(af4, bf4, c, K); - } else if (rows_per_block == 16) { - LLGemm1_kernel<16><<>>(af4, bf4, c, K); - } else { - NUM_BLOCKS = M / 4; - LLGemm1_kernel<4><<>>(af4, bf4, c, K); - } - - cudaError_t err = cudaGetLastError(); - if (cudaSuccess != err) - throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); -} - -// instantiate the kernel template for T=float: -// template void AddGPUKernel(float *in_a, float *in_b, float *out_c, -// const int M, const int K, cudaStream_t stream); - -const unsigned int TILE_WIDTH = 32; - -// Compute C = A * B -__global__ void matrixMultiplyShared(float* A, float* B, float* C, int numARows, - int numAColumns, int numBRows, - int numBColumns, int numCRows, - int numCColumns) { - __shared__ float sA[TILE_WIDTH][TILE_WIDTH]; // Tile size of 32x32 - __shared__ float sB[TILE_WIDTH][TILE_WIDTH]; - - int Row = blockDim.y * blockIdx.y + threadIdx.y; - int Col = blockDim.x * blockIdx.x + threadIdx.x; - float Cvalue = 0.0; - sA[threadIdx.y][threadIdx.x] = 0.0; - sB[threadIdx.y][threadIdx.x] = 0.0; - - for (int ph = 0; ph < (((numAColumns - 1) / TILE_WIDTH) + 1); ph++) { - if ((Row < numARows) && (threadIdx.x + (ph * TILE_WIDTH)) < numAColumns) { - sA[threadIdx.y][threadIdx.x] = - A[(Row * numAColumns) + threadIdx.x + (ph * TILE_WIDTH)]; - } else { - sA[threadIdx.y][threadIdx.x] = 0.0; - } - if (Col < numBColumns && (threadIdx.y + ph * TILE_WIDTH) < numBRows) { - sB[threadIdx.y][threadIdx.x] = - B[(threadIdx.y + ph * TILE_WIDTH) * numBColumns + Col]; - } else { - sB[threadIdx.y][threadIdx.x] = 0.0; - } - __syncthreads(); - for (int j = 0; j < TILE_WIDTH; ++j) { - Cvalue += sA[threadIdx.y][j] * sB[j][threadIdx.x]; - } - } - if (Row < numCRows && Col < numCColumns) { - C[Row * numCColumns + Col] = Cvalue; - } -} - -void MMGPUKernel(float* in_a, float* in_b, float* out_c, int numARows, - int numAColumns, int numBRows, int numBColumns, int numCRows, - int numCColumns, cudaStream_t stream) { - // Initialize the grid and block dimensions - dim3 dimBlock(TILE_WIDTH, TILE_WIDTH, 1); - dim3 dimGrid((numCColumns / TILE_WIDTH) + 1, (numCRows / TILE_WIDTH) + 1, 1); - //@@ Launch the GPU Kernel here - matrixMultiplyShared<<>>( - in_a, in_b, out_c, numARows, numAColumns, numBRows, numBColumns, numCRows, - numCColumns); - - cudaError_t err = cudaGetLastError(); - if (cudaSuccess != err) - throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); -} - -template -__global__ __launch_bounds__(512) void HGEMV_WFPerRow( - int m, int n, const _Float16* A, int lda, const _Float16* x, _Float16* y) { - int num_row_per_block = CTA / nThreads_per_row; - int row_id = (blockIdx.x * num_row_per_block + threadIdx.y) * MT0; - int inc = (gridDim.x * num_row_per_block) * MT0; - - while (row_id < m) { - float2 sum2[MT0]; - -#pragma unroll - for (int i = 0; i < MT0; ++i) { - sum2[i] = {0.0, 0.0}; - } - - for (int j = threadIdx.x; j < n; j += (nThreads_per_row * MT1)) { - bool is_active = j < n; - if (is_active) { - float2 x2[MT1 >> 1]; -#pragma unroll - for (int offset = 0; offset < MT1; offset += 2) { - x2[offset >> 1] = {x[j + nThreads_per_row * offset], - x[j + nThreads_per_row * (offset + 1)]}; - } - float2 a2[MT0][MT1 >> 1]; -#pragma unroll - for (int i = 0; i < MT0; i++) { -#pragma unroll - for (int offset = 0; offset < MT1; offset += 2) { - a2[i][offset >> 1] = { - A[(row_id + i) * n + j + nThreads_per_row * offset], - A[(row_id + i) * n + j + nThreads_per_row * (offset + 1)]}; - } - } - -#pragma unroll - for (int i = 0; i < MT0; i++) { -#pragma unroll - for (int offset = 0; offset < (MT1 >> 1); offset++) { - sum2[i] += a2[i][offset] * x2[offset]; - } - } - } - } - float sum[MT0]; -#pragma unroll - for (int i = 0; i < MT0; i++) { - sum[i] = sum2[i].x + sum2[i].y; - } - -#pragma unroll - for (int i = 0; i < MT0; i++) { -#pragma unroll - for (int offset = nThreads_per_row >> 1; offset >= 1; - offset = offset >> 1) { - sum[i] += __shfl_down(sum[i], offset, nThreads_per_row); - } - } - if (threadIdx.x == 0) { -#pragma unroll - for (int i = 0; i < MT0; i++) { - y[row_id + i] = sum[i]; - } - } - row_id += inc; - } -} - -void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, - cudaStream_t stream, const int solidx = 0) { - // m -> M, n-> K - dim3 grid(1024); - dim3 block(64, 8); - if (solidx == 0) { - HGEMV_WFPerRow<64, 512, 4, 8><<>>( - M, K, reinterpret_cast(in_a), K, - reinterpret_cast(in_b), - reinterpret_cast<_Float16*>(out_c)); - } else if (solidx == 1) { - HGEMV_WFPerRow<64, 512, 2, 8><<>>( - M, K, reinterpret_cast(in_a), K, - reinterpret_cast(in_b), - reinterpret_cast<_Float16*>(out_c)); - } else if (solidx == 2) { - HGEMV_WFPerRow<64, 512, 1, 8><<>>( - M, K, reinterpret_cast(in_a), K, - reinterpret_cast(in_b), - reinterpret_cast<_Float16*>(out_c)); - } else { - HGEMV_WFPerRow<64, 512, 4, 8><<>>( - M, K, reinterpret_cast(in_a), K, - reinterpret_cast(in_b), - reinterpret_cast<_Float16*>(out_c)); - } - cudaError_t err = cudaGetLastError(); - if (cudaSuccess != err) - throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); -} - -///////////////////////////////////////////// - -#define DTYPE half - -__device__ __forceinline__ int mindiv(int N, int div1, int div2) { - int nPrRnd = div1 * div2; - int rnds0 = N / nPrRnd; - nPrRnd -= div1 * 3; - int rnds3 = N / nPrRnd; - nPrRnd -= div1; - int rnds4 = N / nPrRnd; - nPrRnd -= div1; - int rnds5 = N / nPrRnd; - nPrRnd -= div1; - int rnds6 = N / nPrRnd; - nPrRnd -= div1; - int rnds7 = N / nPrRnd; - nPrRnd -= div1; - int rnds8 = N / nPrRnd; - nPrRnd -= div1; - int rnds9 = N / nPrRnd; - nPrRnd -= div1; - int rtn = div2; - if (rnds0 == rnds3) rtn = div2 - 3; - if (rnds0 == rnds4) rtn = div2 - 4; - if (rnds0 == rnds5) rtn = div2 - 5; - if (rnds0 == rnds6) rtn = div2 - 6; - if (rnds0 == rnds7) rtn = div2 - 7; - if (rnds0 == rnds8) rtn = div2 - 8; - if (rnds0 == rnds9) rtn = div2 - 9; - return rtn; -} - -#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support -// This version targets cases where A[] fits LDS capacity -template -__global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { - using half8 = - __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; - union bigType { - DTYPE h[A_CHUNK]; - float f[A_CHUNK / 2]; - float2 f2[A_CHUNK / 4]; - double d[A_CHUNK / 4]; - half8 h8; - }; - - //---------------------------------------------------- - // Reserving 64 KB of LDS to have 1 WG / CU - // Goal is to bring the activation matrix A to the LDS - // and use it across the lifetime of the work group - // TODO: When activation matrix is larger than 64 KB - // then this is not goint to work! - //---------------------------------------------------- - __shared__ half s[1024 * 32]; - - //---------------------------------------------------- - // Computation of columns that need to be committed to memory! - //---------------------------------------------------- - // uint32_t commitColumn[YTILE]; - // for (uint32_t i = 0; i < YTILE; i++) { - // commitColumn[i] = 1; - //} - - // It's worth trying to load-balance... - int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); - - //---------------------------------------------------- - // Indexing function into the column of weight matrix B - // Algorithm does 64 lane k-splitting / wave and uses - // WG ID and Thread ID to find the index. - //---------------------------------------------------- - uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; - - // Check whether there will be fragmenation! - // This will happen only for the last wave! - // if (n < N && (n + YTILE) >= N) { - // uint32_t startColumn = N - YTILE; - // for (uint32_t i = 0; i < (n - startColumn); i++) { - // commitColumn[i] = 0; - // } - // n = startColumn; - //} - - //---------------------------------------------------- - // Fetch the activation matrix to LDS - // Loop iteration: - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements - // - Each WG will fetch 512 * 16 => 8K elements - // - Then the WG will move to another 8 K elements - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k = 0; k < min(K * M, 32 * 1024); - k += THRDS * WvPrGrp * A_CHUNK) { - uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - - // Transpose of A implementation - // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for - // bank-conflict-free readback - - if (k_in >= min(K * M, 32 * 1024)) break; - - //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; - *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); - //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; - } - __syncthreads(); - - if (threadIdx.y >= _WvPrGrp) return; - - float sum[M][YTILE]; - - //---------------------------------------------------- - // Each wave works on a single column of weight matrix. - // There are 16 waves per WG, and hence, each WG is - // working on 16 columns of weight matrix. Moreover, - // we tile in column direction by YTILE, so when YTILE=1 - // the above math is right, however, when YTILE=2 then - // each wave will be working on 2 columns and WG will - // be working on 32 columns. - // - // Top level loop that makes WGs persistent! - // - WGs iterates across columns of weight matrix - // - Each wave within WG works on a given column(s) - // - After completing first set of columns, WGs start - // working on the next set of available columns - //---------------------------------------------------- - while (n < N) { - //---------------------------------------------------- - // 'sum' accumulates the matrix A x B computation - // split across 64 lanes. - // - // YTILE represents how many column of weight matrix - // are being worked on by each wave. - //---------------------------------------------------- - for (int i = 0; i < YTILE; i++) - for (int m = 0; m < M; m++) sum[m][i] = 0; - - bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; - bigType bigB1[UNRL]; - bigType bigB2[UNRL]; - bigType bigB3[UNRL]; - bigType bigB4[UNRL]; - bigType bigB5[UNRL]; - bigType bigB6[UNRL]; - bigType bigB7[UNRL]; - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. - // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { - // Fetch the weight matrix from memory! - #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); - } - - // Fetch activation matrix from either just LDS or from both LDS / memory - #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // Fetch A activation matrix in interleaved fashion from LDS or memory - - for (int m = 0; m < M; m++) { - // if (k_ + K * m < 32 * 1024) - bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); - // else - // bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); - } - } - - // Do the matrix multiplication in interleaved manner - #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! - #pragma unroll - for (uint32_t m = 0; m < M; m++) { - #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][0]) - : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); - - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- - if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); - if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); - if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); - if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); - if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); - if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); - if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); - } - } - } - } - - //---------------------------------------------------- - // Final reduction step using shuffle - //---------------------------------------------------- - for (int m = 0; m < M; m++) { - for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - } - } - if (threadIdx.x == 63) { - for (int m = 0; m < M; m++) { - for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); - C[n + i + m * N] = __float2half(sum[m][i]); - } - } - } - - n += CuCount * _WvPrGrp * YTILE; - - // Check whether there will be fragmenation! - // This will happen only for the last wave! - // if (n < N && (n + YTILE) >= N) { - // uint32_t startColumn = N - YTILE; - // for (uint32_t i = 0; i < (n - startColumn); i++) { - // commitColumn[i] = 0; - // } - // n = startColumn; - //} - } -} -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support -template -__global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { - UNREACHABLE_CODE -} -#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support - -#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support -// This version targets cases where A[] marginally exceeds LDS capacity -template -__global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSpltK_hf_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { - using half8 = - __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; - union bigType { - DTYPE h[A_CHUNK]; - float f[A_CHUNK / 2]; - float2 f2[A_CHUNK / 4]; - double d[A_CHUNK / 4]; - half8 h8; - }; - - //---------------------------------------------------- - // Reserving 64 KB of LDS to have 1 WG / CU - // Goal is to bring the activation matrix A to the LDS - // and use it across the lifetime of the work group - // TODO: When activation matrix is larger than 64 KB - // then this is not goint to work! - //---------------------------------------------------- - __shared__ half s[1024 * 32]; - - //---------------------------------------------------- - // Computation of columns that need to be committed to memory! - //---------------------------------------------------- - uint32_t commitColumn[YTILE]; - for (uint32_t i = 0; i < YTILE; i++) { - commitColumn[i] = 1; - } - - // It's worth trying to load-balance... - int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); - - //---------------------------------------------------- - // Indexing function into the column of weight matrix B - // Algorithm does 64 lane k-splitting / wave and uses - // WG ID and Thread ID to find the index. - //---------------------------------------------------- - uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; - - // Check whether there will be fragmenation! - // This will happen only for the last wave! - if (n < N && (n + YTILE) >= N) { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) { - commitColumn[i] = 0; - } - n = startColumn; - } - - //---------------------------------------------------- - // Fetch the activation matrix to LDS - // Loop iteration: - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements - // - Each WG will fetch 512 * 16 => 8K elements - // - Then the WG will move to another 8 K elements - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k = 0; k < min(K * M, 32 * 1024); - k += THRDS * WvPrGrp * A_CHUNK) { - uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - - // Transpose of A implementation - // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for - // bank-conflict-free readback - - if (k_in >= min(K * M, 32 * 1024)) break; - - //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; - *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); - //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; - } - __syncthreads(); - - if (threadIdx.y >= _WvPrGrp) return; - - float sum[M][YTILE]; - - //---------------------------------------------------- - // Each wave works on a single column of weight matrix. - // There are 16 waves per WG, and hence, each WG is - // working on 16 columns of weight matrix. Moreover, - // we tile in column direction by YTILE, so when YTILE=1 - // the above math is right, however, when YTILE=2 then - // each wave will be working on 2 columns and WG will - // be working on 32 columns. - // - // Top level loop that makes WGs persistent! - // - WGs iterates across columns of weight matrix - // - Each wave within WG works on a given column(s) - // - After completing first set of columns, WGs start - // working on the next set of available columns - //---------------------------------------------------- - while (n < N) { - //---------------------------------------------------- - // 'sum' accumulates the matrix A x B computation - // split across 64 lanes. - // - // YTILE represents how many column of weight matrix - // are being worked on by each wave. - //---------------------------------------------------- - for (int i = 0; i < YTILE; i++) - for (int m = 0; m < M; m++) sum[m][i] = 0; - - bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; - bigType bigB1[UNRL]; - bigType bigB2[UNRL]; - bigType bigB3[UNRL]; - bigType bigB4[UNRL]; - bigType bigB5[UNRL]; - bigType bigB6[UNRL]; - bigType bigB7[UNRL]; - bigType bigB8[UNRL]; - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. - // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { - // Fetch the weight matrix from memory! - #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); - } - - // Fetch activation matrix from either just LDS or from both LDS / memory - #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // Fetch A activation matrix in interleaved fashion from LDS or memory - - for (int m = 0; m < M; m++) { - if (k_ + K * m < 32 * 1024) - bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); - else - bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); - } - } - - // Do the matrix multiplication in interleaved manner - #pragma unroll - for (uint32_t m = 0; m < M; m++) { - #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! - #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][0]) - : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); - - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- - if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); - if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); - if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); - if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); - if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); - if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); - if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); - } - } - } - } - - //---------------------------------------------------- - // Final reduction step using shuffle - //---------------------------------------------------- - for (int m = 0; m < M; m++) { - for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - } - } - - if (threadIdx.x == 63) { - for (int m = 0; m < M; m++) { - for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); - } - } - } - - n += CuCount * _WvPrGrp * YTILE; - - // Check whether there will be fragmenation! - // This will happen only for the last wave! - if (n < N && (n + YTILE) >= N) { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) { - commitColumn[i] = 0; - } - n = startColumn; - } - } -} - -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support -template -__global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { - UNREACHABLE_CODE -} -#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support - -#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support -// This version targets big A[] cases, where it is much larger than LDS capacity -template -__global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { - using half8 = - __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; - - union bigType { - DTYPE h[A_CHUNK]; - float f[A_CHUNK / 2]; - float2 f2[A_CHUNK / 4]; - double d[A_CHUNK / 4]; - half8 h8; - }; - - //---------------------------------------------------- - // Reserving 64 KB of LDS to have 1 WG / CU - // Goal is to bring the activation matrix A to the LDS - // and use it across the lifetime of the work group - // TODO: When activation matrix is larger than 64 KB - // then this is not goint to work! - //---------------------------------------------------- - __shared__ half s[1024 * 32]; - - //---------------------------------------------------- - // Computation of columns that need to be committed to memory! - //---------------------------------------------------- - uint32_t commitColumn[YTILE]; - for (uint32_t i = 0; i < YTILE; i++) { - commitColumn[i] = 1; - } - - // It's worth trying to load-balance... - int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); - if (threadIdx.y >= _WvPrGrp) return; - - //---------------------------------------------------- - // Indexing function into the column of weight matrix B - // Algorithm does 64 lane k-splitting / wave and uses - // WG ID and Thread ID to find the index. - //---------------------------------------------------- - uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; - - // Check whether there will be fragmenation! - // This will happen only for the last wave! - if (n < N && (n + YTILE) >= N) { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) { - commitColumn[i] = 0; - } - n = startColumn; - } - - //---------------------------------------------------- - // Fetch the activation matrix to LDS - // Loop iteration: - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements - // - Each WG will fetch 512 * 16 => 8K elements - // - Then the WG will move to another 8 K elements - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - #define PCML - #ifndef PCML - for (uint32_t k = 0; k < min(K * M, 32 * 1024); - k += THRDS * WvPrGrp * A_CHUNK) { - uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - - // Transpose of A implementation - // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for - // bank-conflict-free readback - - if (k_in >= min(K * M, 32 * 1024)) break; - - //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; - *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); - //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; - } - __syncthreads(); - #endif - - #define TUC (THRDS * UNRL * A_CHUNK) - uint32_t kBase = 0; - // find biggest k size that fits in LDS - uint32_t kFit = (32 * 1024) / M; - // kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple - // of TUC - kFit = (kFit % TUC == 0) - ? kFit - : (kFit - kFit % TUC); // round up to multiple of TUC - // if (kFit == 0) kFit = TUC; - kFit = min(kFit, K); - - float sum[M][YTILE]; - - //---------------------------------------------------- - // Each wave works on a single column of weight matrix. - // There are 16 waves per WG, and hence, each WG is - // working on 16 columns of weight matrix. Moreover, - // we tile in column direction by YTILE, so when YTILE=1 - // the above math is right, however, when YTILE=2 then - // each wave will be working on 2 columns and WG will - // be working on 32 columns. - // - // Top level loop that makes WGs persistent! - // - WGs iterates across columns of weight matrix - // - Each wave within WG works on a given column(s) - // - After completing first set of columns, WGs start - // working on the next set of available columns - //---------------------------------------------------- - #ifdef PCML - int YW = (YTILE * _WvPrGrp); - uint32_t Nrndp = (N % YW == 0) ? N : (N - N % YW + YW); - while (n < Nrndp) { - #else - while (n < N) { - #endif - //---------------------------------------------------- - // 'sum' accumulates the matrix A x B computation - // split across 64 lanes. - // - // YTILE represents how many column of weight matrix - // are being worked on by each wave. - //---------------------------------------------------- - for (int i = 0; i < YTILE; i++) - for (int m = 0; m < M; m++) sum[m][i] = 0; - - bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; - bigType bigB1[UNRL]; - bigType bigB2[UNRL]; - bigType bigB3[UNRL]; - bigType bigB4[UNRL]; - bigType bigB5[UNRL]; - bigType bigB6[UNRL]; - bigType bigB7[UNRL]; - bigType bigB8[UNRL]; - bigType bigB9[UNRL]; - bigType bigB10[UNRL]; - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. - // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { - #ifdef PCML - if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS - if (k1 != 0) kBase += kFit; - __syncthreads(); - for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) { - uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - if (kBase + kOff >= K) break; - if (kOff >= kFit) break; - for (uint32_t m = 0; m < M; m++) { - uint32_t k_in = kBase + m * K + kOff; - uint32_t k_ot = m * kFit + kOff; - *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); - } - } - __syncthreads(); - } - if (n >= N) continue; - #endif - - // Fetch the weight matrix from memory! - #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- - if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); - if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); - if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); - if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); - if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); - if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); - if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); - } - - // Fetch activation matrix from either just LDS or from both LDS / memory - #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // Fetch A activation matrix in interleaved fashion from LDS or memory - - for (int m = 0; m < M; m++) { - #ifdef PCML - bigA[m][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * m]))); - #else - if (k_ + K * m < 32 * 1024) - bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); - else - bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); - #endif - } - } - - // Do the matrix multiplication in interleaved manner - #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - #pragma unroll - for (uint32_t m = 0; m < M; m++) { - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! - #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][0]) - : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); - - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- - if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); - if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); - if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); - if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); - if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); - if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); - if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); - } - } - } - } - - #ifdef PCML - if (n >= N) { - n += CuCount * _WvPrGrp * YTILE; - kBase = 0; - continue; - } - #endif - - //---------------------------------------------------- - // Final reduction step using shuffle - //---------------------------------------------------- - for (int m = 0; m < M; m++) { - for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - } - } - - if (threadIdx.x == 63) { - for (int m = 0; m < M; m++) { - for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); - } - } - } - - n += CuCount * _WvPrGrp * YTILE; - kBase = 0; - - // Check whether there will be fragmenation! - // This will happen only for the last wave! - if (n < N && (n + YTILE) >= N) { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) { - commitColumn[i] = 0; - } - n = startColumn; - } - } -} -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support -template -__global__ void wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { - UNREACHABLE_CODE -} -#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support - -void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, - const int K_in, const int N_in, cudaStream_t stream, - const int CuCount = 0) { - dim3 grid(CuCount); - half* af4 = reinterpret_cast(in_a); - const half* bf4 = reinterpret_cast(in_b); - auto* c = reinterpret_cast(out_c); - -#define WVSPLTK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ - _N) \ - { \ - dim3 block(64, _WvPrGrp); \ - /*wvSpltK_hf:*/ \ - if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ - wvSpltK_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ - <<>>(K_in, M_in, af4, bf4, c, CuCount); \ - } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ - wvSpltK_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ - <<>>(K_in, M_in, af4, bf4, c, CuCount); \ - } else { \ - wvSpltK_hf_big_<64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ - <<>>(K_in, M_in, af4, bf4, c, CuCount); \ - } \ - } - - switch (N_in) { - case 1: - WVSPLTK(16, 2, 2, 2, 2, 2, 2, 1) // MI308 - break; - case 2: - WVSPLTK(16, 2, 2, 2, 2, 2, 2, 2) // MI308 - break; - case 3: - WVSPLTK(16, 4, 7, 7, 1, 1, 1, 3) // MI308 - break; - case 4: - WVSPLTK(16, 4, 7, 7, 1, 1, 1, 4) // MI308 - break; - default: - throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + - "," + std::to_string(K_in) + "," + - std::to_string(N_in)); - } - - cudaError_t err = cudaGetLastError(); - if (cudaSuccess != err) { - throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); - } -} \ No newline at end of file diff --git a/csrc/rocm/fused_kernels copy.cu b/csrc/rocm/fused_kernels copy.cu deleted file mode 100644 index 4f3eea4562949..0000000000000 --- a/csrc/rocm/fused_kernels copy.cu +++ /dev/null @@ -1,195 +0,0 @@ -#include -#include -#include -#include - -constexpr int WARP_SIZE = 64; - -template -__device__ __forceinline__ T silu(const T& x) { - // x * sigmoid(x) - return (T)(((float)x) / (1.0f + expf((float)-x))); -} - -template -__device__ __forceinline__ T loadnt(T* addr) { - return __builtin_nontemporal_load(addr); -} - -__device__ __forceinline__ float4 load_ntmprl(const float4* addr) { - auto addr_alias = reinterpret_cast(addr); - auto dat0 = loadnt(addr_alias); - auto dat1 = loadnt(addr_alias + 1); - auto dat2 = loadnt(addr_alias + 2); - auto dat3 = loadnt(addr_alias + 3); - // auto dat0 = *(addr_alias); - // auto dat1 = *(addr_alias+1); - // auto dat2 = *(addr_alias+2); - // auto dat3 = *(addr_alias+3); - return make_float4(dat0, dat1, dat2, dat3); -} - -// TBlock fetches entire rows of A, and entire col of B (K dimension); assume -// N=1 for time being grid is M/A_NUM_ROWS blocks -template -__global__ void LLGemm_Silu_kernel(float4* af4, __half2* bf4, _Float16* c, - const int d) { - __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; - const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 * blockDim.x; - const int row_addr_d = row_addr + d * blockDim.x; - // int row_addr_1 = row_addr + CUDA_NUM_THREADS; - // int row_addr_2 = row_addr_1 + CUDA_NUM_THREADS; - // int row_addr_3 = row_addr_2 + CUDA_NUM_THREADS; - const int threadid = threadIdx.x; - const int warp = threadIdx.x / WARP_SIZE; - const int lane = threadIdx.x % WARP_SIZE; - const int num_warps = blockDim.x / WARP_SIZE; - const int qwarpid = threadid / 16; - const int qthreadid = threadid % 16; - float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; - // float4 colB_elem4; - __half2 colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; - float4 sum4; //[NUM_A_ROWS_PER_BLOCK]; - float acc[NUM_A_ROWS_PER_BLOCK]; //= 0.0; - __half2 acch2; - __half2 oval; - - // rowA_elem4 = af4[row_addr + threadid]; - //__syncthreads(); - // rowA_elem4_1 = af4[row_addr_1 + threadid]; - // rowA_elem4_2 = af4[row_addr_2 + threadid]; - // rowA_elem4_3 = af4[row_addr_3 + threadid]; -#pragma unroll - for (int i = 0; i < NUM_A_ROWS_PER_BLOCK / 2; i++) { - rowA_elem4[2 * i] = load_ntmprl(&af4[row_addr + i * blockDim.x + threadid]); - rowA_elem4[2 * i + 1] = - load_ntmprl(&af4[row_addr_d + i * blockDim.x + threadid]); - // rowA_elem4[i] = af4[row_addr + i*blockDim.x + threadid]; - //__syncthreads(); - } - colB_elem4x = bf4[threadid * 4 + 0]; - colB_elem4y = bf4[threadid * 4 + 1]; - colB_elem4z = bf4[threadid * 4 + 2]; - colB_elem4w = bf4[threadid * 4 + 3]; - - // __syncthreads(); - __half2 Af2; - __half2 Bf2; - float2 S; - // auto Bh2ptr = reinterpret_cast<__half2 *>(&colB_elem4); - // auto Bf2x = *Bh2ptr; - // auto Bf2y = *(Bh2ptr+1); - // auto Bf2z = *(Bh2ptr+2); - // auto Bf2w = *(Bh2ptr+3); - auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4); - __half2* ah2lptr; -#pragma unroll - for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { - ah2lptr = Ah2ptr + i * 4; - Af2 = *(ah2lptr); - acch2 = __hmul2(Af2, colB_elem4x); - Af2 = *(ah2lptr + 1); - acch2 = __hfma2(Af2, colB_elem4y, acch2); - Af2 = *(ah2lptr + 2); - acch2 = __hfma2(Af2, colB_elem4z, acch2); - Af2 = *(ah2lptr + 3); - acch2 = __hfma2(Af2, colB_elem4w, acch2); - S = __half22float2(acch2); - acc[i] = S.x + S.y; - } - -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { -#pragma unroll - for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { - acc[i] += __shfl_xor(acc[i], mask); - } - } - - // Warp leaders store the data to shared memory. - // if (lane == 0) { - // #pragma unroll - // for (int i=0; i= 1; mask /= 2) { - // #pragma unroll - // for (int i=0; i -void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int K, - cudaStream_t stream, const int rows_per_block = 4) { - float4* af4 = reinterpret_cast(in_a); - auto* bf4 = reinterpret_cast<__half2*>(in_b); - auto* c = reinterpret_cast<_Float16*>(out_c); - const int d = M / 2; - const int NUM_THREADS = K * 2 / 16; - int NUM_BLOCKS = M / rows_per_block; - if (rows_per_block == 2) { - LLGemm_Silu_kernel<2> - <<>>(af4, bf4, c, d); - } else if (rows_per_block == 4) { - LLGemm_Silu_kernel<4> - <<>>(af4, bf4, c, d); - } else if (rows_per_block == 8) { - LLGemm_Silu_kernel<8> - <<>>(af4, bf4, c, d); - } else if (rows_per_block == 16) { - LLGemm_Silu_kernel<16> - <<>>(af4, bf4, c, d); - } else { - NUM_BLOCKS = M / 4; - LLGemm_Silu_kernel<4> - <<>>(af4, bf4, c, d); - } - - cudaError_t err = cudaGetLastError(); - if (cudaSuccess != err) - throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); -} diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 6b962c8b8f73a..18c72f937f90a 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -19,5 +19,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& context_lens, int64_t block_size, int64_t max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - double k_scale, double v_scale); + const std::string& kv_cache_dtype, double k_scale, + double v_scale); diff --git a/tests/kernels/test_attention_custom.py b/tests/kernels/test_attention_custom.py index 5efaee2e200ed..65cfbb9d9872e 100644 --- a/tests/kernels/test_attention_custom.py +++ b/tests/kernels/test_attention_custom.py @@ -164,79 +164,27 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale - kv_scale = 1.0 + k_scale = v_scale = 1.0 # Call the paged attention kernel. output = torch.empty_like(query) - if version == "v1": - ops.paged_attention_v1( - output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - context_lens, - block_size, - max_context_len, - alibi_slopes, - kv_cache_dtype, - kv_scale, - ) - elif version == "v2" or version == "custom": - num_partitions = ((max_context_len + PARTITION_SIZE - 1) // - PARTITION_SIZE) - assert PARTITION_SIZE % block_size == 0 - num_seqs, num_heads, head_size = output.shape - tmp_output = torch.empty( - size=(num_seqs, num_heads, num_partitions, head_size), - dtype=output.dtype, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, num_partitions), - dtype=torch.float32, - ) - max_logits = torch.empty_like(exp_sums) - if version == "v2": - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - context_lens, - block_size, - max_context_len, - alibi_slopes, - kv_cache_dtype, - kv_scale, - ) - elif version == "custom": - ops.paged_attention_custom( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - context_lens, - block_size, - max_context_len, - alibi_slopes, - kv_cache_dtype, - ) - else: - raise AssertionError(f"Unknown version: {version}") + num_partitions = ((max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE) + assert PARTITION_SIZE % block_size == 0 + num_seqs, num_heads, head_size = output.shape + tmp_output = torch.empty( + size=(num_seqs, num_heads, num_partitions, head_size), + dtype=output.dtype, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, num_partitions), + dtype=torch.float32, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_rocm(output, exp_sums, max_logits, tmp_output, query, + key_cache, value_cache, num_kv_heads, scale, + block_tables, context_lens, block_size, + max_context_len, alibi_slopes, kv_cache_dtype, + k_scale, v_scale) # Run the reference implementation. if kv_cache_dtype == "fp8": diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 16236b7ba6acb..88c3ef6ece511 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -21,6 +21,7 @@ marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE from vllm.scalar_type import scalar_types +from vllm.utils import is_hip def torch_moe(a, w1, w2, score, topk): @@ -154,6 +155,8 @@ def compute_max_diff(output, output_ref): torch.abs(output_ref)) +@pytest.mark.skipif(is_hip(), + reason="Make this test work with MoE padding on HIP") @pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [128, 2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 1024, 512]) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ef0fdc1bbdfe6..c0e72cfd55b6f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -139,30 +139,6 @@ def paged_attention_v2( blocksparse_block_size, blocksparse_head_sliding_step) -def paged_attention_rocm( - out: torch.Tensor, - exp_sum: torch.Tensor, - max_logits: torch.Tensor, - tmp_out: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - num_kv_heads: int, - scale: float, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, - block_size: int, - max_seq_len: int, - alibi_slopes: Optional[torch.Tensor], - kv_cache_dtype: str, -) -> None: - torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, - key_cache, value_cache, num_kv_heads, - scale, block_tables, seq_lens, - block_size, max_seq_len, alibi_slopes, - kv_cache_dtype) - - # pos encoding ops def rotary_embedding( positions: torch.Tensor, @@ -956,10 +932,11 @@ def paged_attention_rocm( k_scale: int, v_scale: int, ) -> None: - torch.ops._rocm_C.paged_attention( - out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, - num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, - alibi_slopes, kv_cache_dtype, k_scale, v_scale) + torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, + key_cache, value_cache, num_kv_heads, + scale, block_tables, seq_lens, + block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale) def wvSpltK(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, N: int, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index b6d725e9ce770..9dd74081390a7 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -545,8 +545,8 @@ def forward( block_size = value_cache.shape[3] gqa_ratio = num_heads // self.num_kv_heads use_custom = use_rocm_custom_paged_attention( - decode_query.dtype, head_size, block_size, self.kv_cache_dtype, - gqa_ratio, decode_meta.max_decode_seq_len) + decode_query.dtype, head_size, block_size, gqa_ratio, + decode_meta.max_decode_seq_len) if use_custom: max_seq_len = decode_meta.max_decode_seq_len max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // @@ -625,8 +625,8 @@ def _sdpa_attention( def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, - block_size: int, kv_cache_dtype: str, - gqa_ratio: int, max_seq_len: int) -> bool: + block_size: int, gqa_ratio: int, + max_seq_len: int) -> bool: # rocm custom page attention not support on navi (gfx1*) return (envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN and not ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16) diff --git a/vllm/config.py b/vllm/config.py index eca195b324909..1644bc618cb7a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -281,7 +281,8 @@ def _parse_quant_hf_config(self): def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] rocm_supported_quantization = [ - "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8" + "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", + "fbgemm_fp8" ] optimized_quantization_methods = [ "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", diff --git a/vllm/entrypoints/fast_sync_llm.py b/vllm/entrypoints/fast_sync_llm.py index a36d109af5ec1..c948fc97feeb9 100644 --- a/vllm/entrypoints/fast_sync_llm.py +++ b/vllm/entrypoints/fast_sync_llm.py @@ -78,7 +78,7 @@ def run_engine(self): self.result_queue.put(("Ready", None, None)) prompt_lens = {} - tokens = {} + tokens = {} # type: ignore log_interval = 100 poll_interval = envs.VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS try: diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 20c96fbcaed90..e3d9f6929adfc 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -87,7 +87,7 @@ def apply_fp8_linear( input: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, - out_dtype: torch.dtype, + out_dtype: Optional[torch.dtype] = None, input_scale: Optional[torch.Tensor] = None, input_scale_ub: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, @@ -98,6 +98,9 @@ def apply_fp8_linear( # If dynamic, layer.input_scale is None and x_scale computed from x. # If static, layer.input_scale is scalar and x_scale is input_scale. + if out_dtype is None: + out_dtype = input.dtype + # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A if cutlass_fp8_supported: qinput, x_scale = ops.scaled_fp8_quant( @@ -168,11 +171,13 @@ def apply_fp8_linear( # GEMM # This computes C = (X * W). # Output in fp32 to allow subsequent ops to happen in-place - output, _ = torch._scaled_mm(qinput, - weight, - scale_a=TORCH_DEVICE_IDENTITY, - scale_b=TORCH_DEVICE_IDENTITY, - out_dtype=torch.float32) + output = torch._scaled_mm(qinput, + weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, + out_dtype=torch.float32) + if type(output) is tuple and len(output) == 2: + output = output[0] # Unpad (undo num_token_padding) output = torch.narrow(output, 0, 0, input.shape[0]) x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])