[Kernel]: Cutlass 2:4 Sparsity + FP8/Int8 Quant Support #10995
Conversation
Removed cmake check for cusparseLt, needs to be reverted when the cmake issue is resolved.
…tils instead of our decompressor
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

print("in test")
nit: remove cruft
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = test_get_sm_version_num();
// Hopper
nit: what's this comment for?
maybe for a future PR, but there should be more tests here: test more shapes, and there should be an opcheck test (see test_cutlass_support_opcheck) and a CUDA graph test (see test_cutlass_cuda_graph). Use vllm/tests/kernels/test_cutlass.py as inspiration (with the exception of the azp stuff, I assume)
benchmarks/benchmark_throughput.py
Outdated
@@ -361,7 +361,8 @@ def main(args: argparse.Namespace):
     # TODO(vllm-project/vllm/issues/9778): Count molti-modal token length.
     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
           f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
-          f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
+          f"{total_output_tokens / elapsed_time:.2f} output tokens/s, "
+          f"{total_num_tokens=} | {total_output_tokens=}")
This looks like debug cruft and should be reverted if so
Done
csrc/cutlass_extensions/common.hpp
Outdated
inline uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
could you put this in csrc/core/math.hpp? @SageMoore is adding similar utilities to that file in #10867
Done
csrc/cutlass_extensions/common.hpp
Outdated
#define CUDA_CHECK(status)                                              \
  {                                                                     \
    cudaError_t error = status;                                         \
    if (error != cudaSuccess) {                                         \
      std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \
                << " at line: " << __LINE__ << std::endl;               \
      exit(EXIT_FAILURE);                                               \
    }                                                                   \
  }
We should throw an exception here, and it should behave generally the same way that CUTLASS_CHECK does.
(I do like the line number reporting though, so it would be nice if you could add it to both)
Suggested change:
#define CUDA_CHECK(status)                 \
  {                                        \
    TORCH_CHECK(status == cudaSuccess,     \
                cudaGetErrorString(status)) \
  }
Done
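For reference, here is a sketch of how the macro could keep the line-number reporting while switching to TORCH_CHECK; it binds the status to a local first so the macro can also wrap a function call without evaluating it twice. This is illustrative only, not necessarily the exact macro that landed in the PR:

#define CUDA_CHECK(status)                                           \
  {                                                                  \
    cudaError_t error = status;                                      \
    TORCH_CHECK(error == cudaSuccess,                                \
                "Got bad cuda status: ", cudaGetErrorString(error),  \
                " at ", __FILE__, ":", __LINE__);                    \
  }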
#include <cudaTypedefs.h>

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>

#include <iostream>
#include <sstream>
#include <vector>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/detail/dependent_false.hpp"

#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/common.hpp"

#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"

#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

#include <iostream>

#include "cutlass/cutlass.h"

#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"

#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "sparse_scaled_mm_c3x.cuh"
Please clean up these includes. I see some duplicates. Could you try to minimize the number of includes, i.e. no duplicates and nothing that's unnecessary?
Also please turn clang-format off for the includes, as CUTLASS headers don't tolerate reordering.
// clang-format will break include orders
// clang-format off
#include "your.h"
#include "includes.h"
#include "here.h"
// clang-format on
Done
These should be pared down further. For example: "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" already includes "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp", and most of our CUTLASS kernels don't interact directly with the code in broadcast_load_epilogue_c3x.hpp, so they should only include scaled_mm_epilogues_c3x.hpp. However, this sparsify_and_compress kernel doesn't use any epilogues at all, so it shouldn't include either of them. Could you take another look at these includes and the includes in your other kernels as well?
Done. CUTLASS's CompressorUtility requires that a Gemm be defined with all operand types, schedules, etc., including an epilogue, albeit the default one. I had previously used my default Gemm config with ScaledEpilogue for this Gemm, but per this review I replaced it with an on-the-spot Gemm kernel setup similar to the examples provided in CUTLASS. I also mention this in a comment in the code now.
// Just a dummy value
int32_t n = 128;
Could you expand on this comment?
It was just needed to instantiate a problem shape to use the compressor utility in CUTLASS. I replaced it with 1 in the problem shape directly.
Please put this in a comment in the code so that it is documented there
Done
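A sketch of what the documented replacement might look like; the variable names m and k are assumed from the surrounding compressor code, and the exact comment wording in the PR may differ:

// The CUTLASS compressor only needs a syntactically valid GEMM problem shape,
// and the N dimension does not affect how A is compressed, so 1 is used
// directly instead of a named dummy value.
typename GemmKernel::ProblemShape prob_shape{m, /*n=*/1, k, /*l=*/1};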
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1)
is there any requirement for the divisibility of a.stride(0)? Do we test odd values of m?
No. Since we're doing column-major output in the kernels, there's no requirement. For row-major output, the batch size has to be a multiple of 8.
I thought this was the weight matrix, so batch isn't relevant here
You're right, my bad for misunderstanding. The intermediate dimension of the matmul has to be divisible by 4 to satisfy the 2:4 sparsity pattern, so a.stride(0) % 4 == 0 must hold. I added a check for this divisibility.
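A minimal sketch of that check, assuming a is the operand handed to the compressor; the exact error message and surrounding checks in the PR may differ:

// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1);  // row-major, contiguous innermost dimension
TORCH_CHECK(a.stride(0) % 4 == 0,
            "the reduction dimension must be a multiple of 4 for 2:4 sparsity");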
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
Since this comment is epilogue-specific and the epilogues are not defined in this file, I think this comment should be removed
Done
vllm/_custom_ops.py
Outdated
def cutlass_compress_entry(a: torch.Tensor) \
        -> Tuple[torch.Tensor, torch.Tensor]:
    assert (a.dtype in [
        torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16
    ])

    # e.dtype: torch.uint8 so elemsPerElemE = 8b / 2b_per_nz = 4
    elemsPerElemE = 4

    m = a.shape[0]
    k = a.shape[1]
    a_compressed = torch.empty((m, k // 2), dtype=a.dtype, device=a.device)
    e = torch.empty((m, k // 2 // elemsPerElemE),
                    dtype=torch.uint8,
                    device=a.device)

    if not (torch.ops._C.cutlass_compress_entry(a_compressed, e, a)):
        raise ValueError

    return a_compressed, e


def cutlass_scaled_sparse_mm(
Could you add high-level comments for what these are doing? In particular, could you describe what e is?
Done
Looks good to me now, thanks for the hard work!
LGTM too, just left a few very minor refactor/comment nits. Thanks for the hard work and iterations!
csrc/torch_bindings.cpp
Outdated
ops.def(
    "cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
    " Tensor b,"
    " Tensor e, Tensor a_scales,"
nit: can you update argument naming to match, i.e. bt_nzs and bt_meta
Done
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;

// Interface stride expected from the argument a (will get transposed)
nit: can you elaborate on this a bit, i.e. add something about the fact that we compute C^t = B^t @ A^t but we assume B is transposed before compressing, hence the bt_<x> naming
Done
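One possible wording for the requested comment, reconstructed from the reviewer's description rather than copied from the PR:

// cutlass_scaled_sparse_mm computes D^t = B^t @ A^t. The sparse operand B is
// assumed to be transposed (and 2:4-compressed) before the kernel is called,
// hence the bt_nzs / bt_meta argument names.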
auto layout_A = make_cute_layout<StrideA>(a, "A");
auto layout_D = make_cute_layout<StrideD>(out, "D");

auto stride_At = layout_A.stride();
nit: can you add a comment here explaining why At is the same stride as A for cutlass
Done
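One way the requested explanation could read, shown next to the line it documents (wording assumed, not the PR's exact comment):

// A row-major (m, k) matrix and its column-major (k, m) transpose are
// described by the same stride values, so the layout computed for A can be
// reused as-is for A^t.
auto stride_At = layout_A.stride();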
using GemmKernel = typename Gemm::GemmKernel;
typename GemmKernel::ProblemShape prob_shape{
    (int)bt_nzs.size(0), (int)size<0>(layout_A), (int)size<1>(layout_A), 1};
nit: we should avoid C-style casts for consistency (use static_cast here)
Done
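Applied to the line quoted above, the suggested change would look roughly like this (a sketch of the nit, not necessarily the final formatting):

typename GemmKernel::ProblemShape prob_shape{
    static_cast<int>(bt_nzs.size(0)), static_cast<int>(size<0>(layout_A)),
    static_cast<int>(size<1>(layout_A)), 1};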
csrc/torch_bindings.cpp
Outdated
// CUTLASS sparse matrix compressor
ops.def(
    "cutlass_sparse_compress_entry(Tensor! a_compressed, Tensor! e,"
nit: maybe update this to match the argument naming for cutlass_scaled_sparse_mm, i.e. Tensor! a_nzs, Tensor! a_meta
Done
/// Make A structured sparse by replacing elements with 0 and compress it
template <typename ElementA_, typename ElementAcc_>
bool cutlass_sparse_compress(torch::Tensor& a_compressed, torch::Tensor& e,
nit: maybe update this to match the argument naming for cutlass_scaled_sparse_mm, i.e. Tensor! a_nzs, Tensor! a_meta
Done
csrc/cutlass_extensions/common.hpp
Outdated
 * Helper function for checking CUTLASS errors
 */
#define CUTLASS_CHECK(status) \
  { \
Maybe extract status first (like below) so this macro can directly wrap expressions like function calls and not double-evaluate them?
Done
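A sketch of what "extract status first" could look like for CUTLASS_CHECK, mirroring the CUDA_CHECK pattern so a wrapped function call is evaluated exactly once (the PR's final macro may differ):

#define CUTLASS_CHECK(status)                          \
  {                                                    \
    cutlass::Status error = status;                    \
    TORCH_CHECK(error == cutlass::Status::kSuccess,    \
                cutlassGetStatusString(error));        \
  }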
CMakeLists.txt
Outdated
GIT_PROGRESS TRUE

# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
GIT_SHALLOW TRUE
# GIT_SHALLOW FALSE
Should this be uncommented as FALSE now?
Suggested change:
GIT_SHALLOW FALSE
Yeah sure. It's also the default I think but better be explicit as you said.
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
future work: what about per-channel/per-token scales?
Yeah. We can also use that for benchmarking. I put this here only because it's similar to the dense benchmarking script.
@classmethod
def get_min_capability(cls) -> int:
    return 90
Worth leaving a note that this is due to cutlass 3.x kernel restrictions since we do have fp16+int8 support here
Done
… tests for >90 sm capability
…#10995) Co-authored-by: Faraz Shahsavan <[email protected]> Co-authored-by: ilmarkov <[email protected]> Co-authored-by: Rahul Tuli <[email protected]> Co-authored-by: [email protected] <[email protected]> Signed-off-by: Sage Moore <[email protected]>
…-project#10995)" This reverts commit 60508ff.
Summary
From Neural Magic