diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
index d32bbca9b65d6..c69e87999ae71 100644
--- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
+++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
@@ -3,9 +3,9 @@
 
 /* This file defines custom epilogues for fusing channel scales, token scales,
    bias, and activation zero-points onto a GEMM operation using the
-   CUTLASS 3.x API, for pre sm90 (Hopper) NVIDIA GPUs.
+   CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.
 
-   Epilogues must contain a public type named EVTCompute of type Sm90EVT,
+   Epilogues must contain a public type named EVTCompute of type Sm80EVT,
    as well as a static prepare_args function that constructs an
    EVTCompute::Arguments struct.
 */
diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py
index 40a2c528cc136..1d8d03ece9da4 100644
--- a/csrc/quantization/machete/generate.py
+++ b/csrc/quantization/machete/generate.py
@@ -82,7 +82,7 @@
 {% for impl_config in impl_configs %}
 {% set t = impl_config.types -%}
 {% set type_sig = gen_type_sig(t) -%}
-  if (args.btype == {{VLLMScalarTypeTag[t.b]}}
+  if (args.b_type == {{VLLMScalarTypeTag[t.b]}}
       && a_type == {{TorchTypeTag[t.a]}}
       && out_type == {{TorchTypeTag[t.out]}}
       && {%if t.b_group_scale != void -%}
@@ -105,7 +105,7 @@
   TORCH_CHECK_NOT_IMPLEMENTED(
       false, "machete_mm(..) is not implemented for "
      "a_type=", args.A.scalar_type(),
-      ", b_type=", args.btype.str(),
+      ", b_type=", args.b_type.str(),
       ", out_type=", out_type,
       ", with_group_scale_type=", maybe_g_scales_type
       ? toString(*maybe_g_scales_type) : "None",
@@ -231,7 +231,7 @@
   TORCH_CHECK_NOT_IMPLEMENTED(false,
       "prepack_B_dispatch(..) is not implemented for "
      "atype = ", args.a_type,
-      ", btype = ", args.b_type.str(),
+      ", b_type = ", args.b_type.str(),
       ", with_group_scales_type= ", args.maybe_group_scales_type
       ? toString(*args.maybe_group_scales_type) : "None");
 }
diff --git a/csrc/quantization/machete/machete_mm_launcher.cuh b/csrc/quantization/machete/machete_mm_launcher.cuh
index 5b60263266952..4b0da5b303e0c 100644
--- a/csrc/quantization/machete/machete_mm_launcher.cuh
+++ b/csrc/quantization/machete/machete_mm_launcher.cuh
@@ -12,7 +12,7 @@ namespace machete {
 struct MMArgs {
   torch::Tensor const& A;
   torch::Tensor const& B;
-  vllm::ScalarType const& btype;
+  vllm::ScalarType const& b_type;
   c10::optional<at::ScalarType> const& maybe_out_type;
   c10::optional<torch::Tensor> const& maybe_group_scales;
   c10::optional<torch::Tensor> const& maybe_group_zeros;
diff --git a/csrc/quantization/machete/machete_pytorch.cu b/csrc/quantization/machete/machete_pytorch.cu
index bdf841e9aa444..da2c2fb0d3e77 100644
--- a/csrc/quantization/machete/machete_pytorch.cu
+++ b/csrc/quantization/machete/machete_pytorch.cu
@@ -9,13 +9,13 @@ namespace machete {
 using namespace vllm;
 
 std::vector<std::string> supported_schedules(
-    at::ScalarType a_type, int64_t btype_id,
+    at::ScalarType a_type, int64_t b_type_id,
     c10::optional<at::ScalarType> maybe_group_scales_type,
     c10::optional<at::ScalarType> maybe_group_zeros_type,
     c10::optional<at::ScalarType> maybe_channel_scales_type,
     c10::optional<at::ScalarType> maybe_token_scales_type,
     c10::optional<at::ScalarType> maybe_out_type) {
-  ScalarType const b_type = ScalarType::from_id(btype_id);
+  ScalarType const b_type = ScalarType::from_id(b_type_id);
   return supported_schedules_dispatch({
       .a_type = a_type,
       .b_type = b_type,
@@ -28,7 +28,7 @@ std::vector<std::string> supported_schedules(
 }
 
 torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
-                 int64_t btype_id,
+                 int64_t b_type_id,
                  c10::optional<at::ScalarType> const& maybe_out_type,
                  c10::optional<torch::Tensor> const& maybe_group_scales,
                  c10::optional<torch::Tensor> const& maybe_group_zeros,
@@ -36,10 +36,10 @@ torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
                  c10::optional<torch::Tensor> const& maybe_channel_scales,
                  c10::optional<torch::Tensor> const& maybe_token_scales,
                  c10::optional<std::string> maybe_schedule) {
-  ScalarType const b_type = ScalarType::from_id(btype_id);
+  ScalarType const b_type = ScalarType::from_id(b_type_id);
   return mm_dispatch({.A = A,
                       .B = B,
-                      .btype = b_type,
+                      .b_type = b_type,
                       .maybe_out_type = maybe_out_type,
                       .maybe_group_scales = maybe_group_scales,
                       .maybe_group_zeros = maybe_group_zeros,
@@ -50,9 +50,9 @@ torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
 }
 
 torch::Tensor prepack_B(
-    torch::Tensor const& B, at::ScalarType const& a_type, int64_t btype_id,
+    torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id,
     c10::optional<at::ScalarType> const& maybe_group_scales_type) {
-  ScalarType const b_type = ScalarType::from_id(btype_id);
+  ScalarType const b_type = ScalarType::from_id(b_type_id);
   return prepack_B_dispatch(
       {.B = B,
        .a_type = a_type,
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 776027176ca14..dbd6d631e12d5 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -200,7 +200,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "machete_mm("
       " Tensor A,"
       " Tensor B,"
-      " int btype,"
+      " int b_type,"
       " ScalarType? out_type,"
       " Tensor? group_scales,"
       " Tensor? group_zeros,"