From 8172f4e41ed3fa021e20252907566d9e07c104a9 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 9 Aug 2024 04:21:23 +0000 Subject: [PATCH] review comments --- CMakeLists.txt | 1 - csrc/cuda_utils.h | 2 +- csrc/cutlass_extensions/cute_utils.cuh | 10 +++++ csrc/cutlass_extensions/torch_utils.hpp | 24 +++-------- .../vllm_numeric_conversion.cuh | 9 ++-- csrc/ops.h | 2 - csrc/quantization/machete/generate.py | 10 ++--- .../machete/machete_mm_launcher.cuh | 42 +++++++++---------- csrc/quantization/machete/machete_pytorch.cu | 11 ++--- csrc/torch_bindings.cpp | 2 - vllm/_custom_ops.py | 4 -- .../layers/quantization/utils/quant_utils.py | 6 +-- 12 files changed, 53 insertions(+), 70 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9c7e5b36f7990..099b4b1787cab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -242,7 +242,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ) if (NOT machete_generation_result EQUAL 0) - file(READ ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log log) message(FATAL_ERROR "Machete generation failed." " Result: \"${machete_generation_result}\"" "\nCheck the log for details: " diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 0830742e457fb..c35224218e91c 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -8,7 +8,7 @@ #define HOST_DEVICE_INLINE inline #define DEVICE_INLINE inline #define HOST_INLINE inline -#endif // CUTE_HOST_DEVICE, CUTE_DEVICE +#endif int64_t get_device_attribute(int64_t attribute, int64_t device_id); diff --git a/csrc/cutlass_extensions/cute_utils.cuh b/csrc/cutlass_extensions/cute_utils.cuh index 1928fbc45155a..14aa51703b6c5 100644 --- a/csrc/cutlass_extensions/cute_utils.cuh +++ b/csrc/cutlass_extensions/cute_utils.cuh @@ -4,6 +4,16 @@ //////////////////////////////////////////////////////////////////// // make_cute_stride +// - instantiates a stride object thats correctly populated base +// on the shape of the tensor and the stride type passed in, +// for example: +// - if s = Stride and shape = {M, N, L} then the stride will be +// constructed as {N, 1}, i.e. Row Major +// - if s = Stride<_1, int> and shape = {M, N, L} then the stride will be +// constructed as {1, M}, i.e. Column Major +// - if s = Stride and shape = {M, N, L} then the stride +// will be constructed as {N, 1, M * N}, i.e. Row Major Batched +// - etc. //////////////////////////////////////////////////////////////////// // diff --git a/csrc/cutlass_extensions/torch_utils.hpp b/csrc/cutlass_extensions/torch_utils.hpp index ced13e5639b72..a70a2f201f361 100644 --- a/csrc/cutlass_extensions/torch_utils.hpp +++ b/csrc/cutlass_extensions/torch_utils.hpp @@ -19,24 +19,6 @@ static inline bool is_column_major(torch::Tensor const tensor) { return tensor.stride(0) == 1 && tensor.stride(1) == tensor.size(0); } -template -T* maybe_data_ptr(c10::optional maybe_tensor, - char const* name) { - if constexpr (std::is_same_v) { - TORCH_CHECK(!maybe_tensor || is_row_major(*maybe_tensor), "Expected ", name, - " to be RowMajor"); - } else if constexpr (std::is_same_v) { - TORCH_CHECK(!maybe_tensor || is_column_major(*maybe_tensor), "Expected ", - name, " to be ColumnMajor"); - } else { - TORCH_CHECK(false, "Unknown Layout"); - } - - return (maybe_tensor == at::nullopt) - ? nullptr - : reinterpret_cast(maybe_tensor->data_ptr()); -} - template T* data_ptr(torch::Tensor const tensor, char const* name) { if constexpr (std::is_same_v) { @@ -51,6 +33,12 @@ T* data_ptr(torch::Tensor const tensor, char const* name) { return reinterpret_cast(tensor.data_ptr()); } +template +T* maybe_data_ptr(c10::optional maybe_tensor, + char const* name) { + return (maybe_tensor) ? data_ptr(*maybe_tensor, name) : nullptr; +} + // // Torch Type to Cutlass Type (equivalent_cutlass_type) // diff --git a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh index 14c4d0f7618ef..2a775add9491a 100644 --- a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh +++ b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh @@ -1,4 +1,4 @@ -// Based off of: +// Based off of: // https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h #pragma once @@ -264,7 +264,6 @@ struct NumericArrayConverter { using source_type_packed_4 = Array; using source_type_packed_2 = Array; - // Not Valid, not supported, only here to satisfy the interface and to avoid // a compile error. ScalarConverter will not actually work until // NumericConverter is implemented @@ -350,7 +349,7 @@ struct NumericArrayConverter { // Not Valid, not supported, only here to satisfy the interface and to avoid // a compile error. ScalarConverter will not actually work until - // NumericConverter is + // NumericConverter is // implemented using ScalarConverter = NumericConverter; @@ -486,8 +485,8 @@ struct NumericArrayConverter { // Not Valid, not supported, only here to satisfy the interface and to avoid // a compile error. ScalarConverter will not actually work until - // NumericConverter is - // implemented + // NumericConverter is + // implemented using ScalarConverter = NumericConverter; diff --git a/csrc/ops.h b/csrc/ops.h index 1652336ae15e4..ae65ef5a47f5d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -85,8 +85,6 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, namespace machete { -std::vector supported_types(); - std::vector supported_schedules( vllm::ScalarTypeTorchPtr const& btype); diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 3f8539f3c0901..0016cf758e2c4 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -27,7 +27,7 @@ #include "../machete_mm_launcher.cuh" namespace machete { -using KernelDispatcher_ = KernelDispatcher< +using GemmDispatcher_ = GemmDispatcher< {{DataTypeTag[type_config.element_a]}}, // ElementA {{DataTypeTag[type_config.element_b]}}, // ElementB {{DataTypeTag[type_config.element_d]}}, // ElementD @@ -36,10 +36,10 @@ {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints {% for s in schedules %}extern torch::Tensor -impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PytorchArguments args); +impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PyTorchArguments args); {% endfor %} template <> -torch::Tensor KernelDispatcher_::dispatch(PytorchArguments args) { +torch::Tensor GemmDispatcher_::dispatch(PyTorchArguments args) { [[maybe_unused]] auto M = args.A.size(0); [[maybe_unused]] auto N = args.B.size(1); [[maybe_unused]] auto K = args.A.size(1); @@ -62,7 +62,7 @@ } template <> -std::vector KernelDispatcher_::supported_schedules() { +std::vector GemmDispatcher_::supported_schedules() { return { {% for s in schedules -%} "{{ gen_sch_name(s) }}"{{ ", @@ -103,7 +103,7 @@ }; torch::Tensor -impl_{{type_name}}_sch_{{schedule_name}}(PytorchArguments args) { +impl_{{type_name}}_sch_{{schedule_name}}(PyTorchArguments args) { bool with_C = args.C.has_value(), with_scales = args.scales.has_value(), with_zeropoints = args.zeros.has_value(); diff --git a/csrc/quantization/machete/machete_mm_launcher.cuh b/csrc/quantization/machete/machete_mm_launcher.cuh index f77c3a429d717..984bd7bce5841 100644 --- a/csrc/quantization/machete/machete_mm_launcher.cuh +++ b/csrc/quantization/machete/machete_mm_launcher.cuh @@ -8,7 +8,7 @@ namespace machete { -struct PytorchArguments { +struct PyTorchArguments { torch::Tensor const A; torch::Tensor const B; c10::optional const& scales; @@ -20,26 +20,26 @@ struct PytorchArguments { c10::optional schedule; }; -template -torch::Tensor run_impl(PytorchArguments args) { +template +torch::Tensor run_impl(PyTorchArguments args) { const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A)); auto device = args.A.device(); auto stream = at::cuda::getCurrentCUDAStream(device.index()); - using ElementA = typename KernelSpeacialization::ElementA; - using ElementB = typename KernelSpeacialization::ElementB; - using ElementC = typename KernelSpeacialization::ElementC; - using ElementD = typename KernelSpeacialization::ElementD; - using ElementScale = typename KernelSpeacialization::ElementScale; - using ElementZero = typename KernelSpeacialization::ElementZero; + using ElementA = typename KernelSpecialization::ElementA; + using ElementB = typename KernelSpecialization::ElementB; + using ElementC = typename KernelSpecialization::ElementC; + using ElementD = typename KernelSpecialization::ElementD; + using ElementScale = typename KernelSpecialization::ElementScale; + using ElementZero = typename KernelSpecialization::ElementZero; - using LayoutA = typename KernelSpeacialization::LayoutA; - using LayoutB = typename KernelSpeacialization::LayoutB; - using LayoutC = typename KernelSpeacialization::LayoutC; - using LayoutD = typename KernelSpeacialization::LayoutD; - using LayoutScale = typename KernelSpeacialization::LayoutScale; - using LayoutZero = typename KernelSpeacialization::LayoutScale; + using LayoutA = typename KernelSpecialization::LayoutA; + using LayoutB = typename KernelSpecialization::LayoutB; + using LayoutC = typename KernelSpecialization::LayoutC; + using LayoutD = typename KernelSpecialization::LayoutD; + using LayoutScale = typename KernelSpecialization::LayoutScale; + using LayoutZero = typename KernelSpecialization::LayoutScale; int M = args.A.size(0); int N = args.B.size(1); @@ -60,19 +60,19 @@ torch::Tensor run_impl(PytorchArguments args) { auto zeros_ptr = maybe_data_ptr(args.zeros, "zeros"); - auto arguments = KernelSpeacialization::create_arguments( + auto arguments = KernelSpecialization::create_arguments( stream, M, N, K, A_ptr, B_ptr, C_ptr, D_ptr, scales_ptr, zeros_ptr, args.alpha.value_or(1), args.beta.value_or(0), args.group_size.value_or(K)); - TORCH_CHECK(KernelSpeacialization::can_implement(arguments), + TORCH_CHECK(KernelSpecialization::can_implement(arguments), "Machete kernel cannot be run with these arguments"); - size_t workspace_size = KernelSpeacialization::get_workspace_size(arguments); + size_t workspace_size = KernelSpecialization::get_workspace_size(arguments); torch::Tensor workspace = torch::empty( workspace_size, torch::TensorOptions().dtype(torch::kU8).device(device)); - KernelSpeacialization::run(arguments, workspace.mutable_data_ptr(), stream); + KernelSpecialization::run(arguments, workspace.mutable_data_ptr(), stream); return D; }; @@ -80,8 +80,8 @@ torch::Tensor run_impl(PytorchArguments args) { template -struct KernelDispatcher { - static torch::Tensor dispatch(PytorchArguments args); +struct GemmDispatcher { + static torch::Tensor dispatch(PyTorchArguments args); static std::vector supported_schedules(); }; diff --git a/csrc/quantization/machete/machete_pytorch.cu b/csrc/quantization/machete/machete_pytorch.cu index 7f9c70c76ddf9..0f68dfdcd0528 100644 --- a/csrc/quantization/machete/machete_pytorch.cu +++ b/csrc/quantization/machete/machete_pytorch.cu @@ -36,14 +36,9 @@ static auto scalar_type_dispatch(ScalarType const& type, Fn fn) { // Interface // -std::vector supported_types() { - return {c10::make_intrusive(vllm::kU4), - c10::make_intrusive(vllm::kS4)}; -} - std::vector supported_schedules(ScalarTypeTorchPtr const& btype) { return scalar_type_dispatch(*btype, [&](auto BType) { - return KernelDispatcher::supported_schedules(); + return GemmDispatcher::supported_schedules(); }); } @@ -55,7 +50,7 @@ torch::Tensor gemm(torch::Tensor const A, torch::Tensor const B, c10::optional const& C, c10::optional alpha, c10::optional beta, c10::optional schedule) { - auto args = PytorchArguments{.A = A, + auto args = PyTorchArguments{.A = A, .B = B, .scales = scales, .zeros = zeros, @@ -69,7 +64,7 @@ torch::Tensor gemm(torch::Tensor const A, torch::Tensor const B, return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES( A.scalar_type(), "machete_gemm", [&] { using ComputeType = equivalent_cutlass_type_t; - return KernelDispatcher::dispatch(args); + return GemmDispatcher::dispatch(args); }); }); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 79b975dca510c..f4c8d406c671b 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -137,8 +137,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("machete_supported_schedules", &machete::supported_schedules); ops.impl("machete_supported_schedules", torch::kCPU, &machete::supported_schedules); - ops.def("machete_supported_types", &machete::supported_types); - ops.impl("machete_supported_types", torch::kCPU, &machete::supported_types); ops.def("machete_gemm", &machete::gemm); ops.impl("machete_gemm", torch::kCUDA, &machete::gemm); ops.def("machete_prepack_B", &machete::prepack_B); diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e7fe2b3983866..6283f2c6bc4b9 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -334,10 +334,6 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, # machete -def machete_supported_types() -> List[ScalarType]: - return torch.ops._C.machete_supported_types() - - def machete_supported_schedules(b_type: ScalarType) -> List[str]: return torch.ops._C.machete_supported_schedules(b_type) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 8ce71415e6820..33f24ff5d54d3 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -127,9 +127,9 @@ def quantize_weights(w: torch.Tensor, w_q = torch.clamp(w_q, min_q_val, max_q_val) # Compute ref (dequantized) - # For some kernels (namely Machete) the scales are applied after the scales - # are applied, for this case computing the reference in similar way allows - # us to use tighter error tolerances in our unit tests. + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. if ref_zero_points_after_scales and zero_points: w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s else: