Commit
[Bugfix][Build/CI] Fix sparse CUTLASS compilation on CUDA [12.0, 12.2)
Signed-off-by: Tyler Michael Smith <[email protected]>
1 parent ca5f54a · commit 4b8652d
Showing 10 changed files with 176 additions and 15 deletions.
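
The title refers to the half-open CUDA toolkit range [12.0, 12.2). For context only (the actual gating condition used by this commit is not visible in this excerpt), build fixes of this kind are typically expressed as preprocessor guards on the toolkit version; a minimal sketch:

```cuda
#include <cuda.h>  // defines CUDA_VERSION as major * 1000 + minor * 10

// Hypothetical guard, not taken from this commit: CUDA 12.0 encodes as
// 12000 and CUDA 12.2 as 12020, so [12.0, 12.2) is the range below.
#if CUDA_VERSION >= 12000 && CUDA_VERSION < 12020
  // Toolkits in [12.0, 12.2): skip or stub out the sparse CUTLASS kernels.
#else
  // Other toolkits: compile the sparse CUTLASS kernels normally.
#endif
```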
New file (+96 lines):

```cuda
#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Custom reduce implementation
namespace vllm {

namespace detail {

template <typename T>
__inline__ __device__ T _sum(T a, T b) {
  return a + b;
}

}  // namespace detail

template <typename T>
using ReduceFnType = T (*)(T, T);

template <typename T, int numLanes = 32>
__inline__ __device__ T warpReduce(T val, ReduceFnType<T> fn) {
  // Butterfly reduction over XOR shuffles: after log2(numLanes) steps,
  // every participating lane holds the fully reduced value.
  for (int mask = numLanes / 2; mask > 0; mask /= 2)
    val = fn(val, __shfl_xor_sync(0xffffffff, val, mask));
  return val;
}

template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduce(T val, ReduceFnType<T> fn) {
  static __shared__ T shared[32];  // Assuming a warp size of 32
  int lane = threadIdx.x % 32;
  int wid = threadIdx.x / 32;

  // Reduce within each warp, then have lane 0 of every warp publish
  // its partial result to shared memory.
  val = warpReduce<T>(val, fn);

  if (lane == 0) shared[wid] = val;

  __syncthreads();

  // The first warp reduces the per-warp partials. Note that T(0) is a
  // valid identity only for sum-like reductions, and blockDim.x is
  // assumed to be a multiple of 32.
  val = (threadIdx.x < blockDim.x / 32) ? shared[lane] : T(0);
  if (wid == 0) val = warpReduce<T>(val, fn);

  return val;
}

template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) {
  return blockReduce<T, maxBlockSize>(val, detail::_sum<T>);
}

}  // namespace vllm

// Kernel using custom reduce: each thread accumulates a grid-stride
// partial sum, the block reduces it, and thread 0 atomically folds the
// block total into the single output value.
template <typename T>
__global__ void custom_reduce_kernel(const T* input, T* output, int n) {
  T sum = 0;
  for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n;
       i += blockDim.x * gridDim.x) {
    sum += input[i];
  }
  sum = vllm::blockReduceSum<T>(sum);
  if (threadIdx.x == 0) {
    atomicAdd(output, sum);
  }
}

// Kernel using CUB reduce: the block size baked into the BlockReduce
// type (256) must match the launch configuration below.
template <typename T>
__global__ void cub_reduce_kernel(const T* input, T* output, int n) {
  typedef cub::BlockReduce<T, 256> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  T sum = 0;
  for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n;
       i += blockDim.x * gridDim.x) {
    sum += input[i];
  }

  sum = BlockReduce(temp_storage).Sum(sum);

  if (threadIdx.x == 0) {
    atomicAdd(output, sum);
  }
}

// Helper function to launch custom reduce kernel
template <typename T>
void custom_reduce(const T* d_in, T* d_out, int n, cudaStream_t stream) {
  int block_size = 256;
  int grid_size = (n + block_size - 1) / block_size;
  custom_reduce_kernel<<<grid_size, block_size, 0, stream>>>(d_in, d_out, n);
}

// Helper function to launch CUB reduce kernel
template <typename T>
void cub_reduce(const T* d_in, T* d_out, int n, cudaStream_t stream) {
  int block_size = 256;
  int grid_size = (n + block_size - 1) / block_size;
  cub_reduce_kernel<<<grid_size, block_size, 0, stream>>>(d_in, d_out, n);
}
```