diff --git a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh
index 4ab75dd081a5b..7561b2505b10e 100644
--- a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh
+++ b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh
@@ -1,14 +1,22 @@
-// Based off of:
-// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
-
 #pragma once
 
 #include "cutlass/numeric_conversion.h"
 #include "cutlass_extensions/vllm_custom_types.cuh"
 #include "cutlass_extensions/cute_utils.cuh"
 
+// This file extends:
+//   https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
+// with vllm-specific type conversions, namely vllm_uint4b8_t and
+// vllm_uint8b128_t, and adds interleaved numeric array converters for
+// specific types. (Interleaved numeric array converters can be more
+// efficient for sub-byte types.)
+
 namespace cutlass {
 
+// InterleavedNumericArrayConverter is like NumericArrayConverter but also
+// deinterleaves converted elements based on IlvBlkLayout; interleaving can
+// make sub-byte converts more efficient by allowing for efficient extraction
+// of sub-byte elements from a 32-bit register.
 template
@@ -48,7 +56,7 @@ struct InterleavedNumericArrayConverter<
   result_type operator()(source_type const& s) const { return convert(s); }
 };
 
-// TODO (Lucas): Implement
+// TODO (LucasWilkinson): Implement
 // for Array <= Array
 // ....
@@ -71,10 +79,10 @@ struct ArrayConverterPacked32Bit {
   static constexpr auto src_elems_per_32bit_reg =
       32 / cutlass::sizeof_bits_v;
 
-  // Maybe not Valid,. ScalarConverter will not actually work unless
-  // NumericConverter is implemented
-  // but it won't be used since we assert N % 2 == 0, just here for compliance
-  // with VectorizedConverter
+  // Maybe not valid: ScalarConverter will not actually work unless
+  // NumericConverter is implemented. However, it won't be used
+  // anyway since we assert N % 2 == 0; it is just here for compliance with
+  // VectorizedConverter.
   using ScalarConverter = NumericConverter;
 
   template
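The interleaving idea referenced in the new header comment can be illustrated outside of CUTLASS. The Python sketch below is purely illustrative: the actual IlvBlkLayout orderings machete uses are defined by the C++ above, and the 0,4,1,5,... order here is an assumed example, not the kernel's layout. It shows why storing sub-byte values interleaved in a 32-bit word lets a converter pull out pairs of elements with one shift-and-mask per byte lane.

# Illustrative only (assumed 0,4,1,5,2,6,3,7 order; not necessarily what
# machete uses). Eight 4-bit values are packed into one 32-bit word so that
# each byte lane holds a pair that is consumed together.
def pack_interleaved_u4(vals, ilv_stride=4):
    assert len(vals) == 8 and all(0 <= v <= 0xF for v in vals)
    order = [i // 2 + (i % 2) * ilv_stride for i in range(8)]
    word = 0
    for dst_nibble, src_idx in enumerate(order):
        word |= vals[src_idx] << (4 * dst_nibble)
    return word

def extract_pairs(word):
    # One shift/mask per byte lane yields (element i, element i + ilv_stride).
    return [((word >> (8 * b)) & 0xF, (word >> (8 * b + 4)) & 0xF)
            for b in range(4)]

print(extract_pairs(pack_interleaved_u4(list(range(8)))))
# -> [(0, 4), (1, 5), (2, 6), (3, 7)]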
diff --git a/csrc/math_utils.h b/csrc/math_utils.h
deleted file mode 100644
index 10668faf59480..0000000000000
--- a/csrc/math_utils.h
+++ /dev/null
@@ -1,44 +0,0 @@
-#pragma once
-
-#include "cuda_utils.h"
-
-#include
-#include
-
-#include
-#include
-
-template
-HOST_DEVICE_INLINE constexpr auto div_ceil(T1 a, T2 b) {
-  return (a + b - 1) / b;
-}
-
-template
-HOST_DEVICE_INLINE constexpr auto round_up(T1 a, T2 b) {
-  return div_ceil(a, b) * b;
-}
-
-template
-HOST_DEVICE_INLINE constexpr auto round_down(T1 a, T2 b) {
-  return (a / b) * b;
-}
-
-template
-inline std::enable_if_t, bool> not_zero(T value) {
-  return value != 0;
-}
-
-template
-inline std::enable_if_t ||
-                             std::is_same_v ||
-                             std::is_same_v,
-                         bool>
-not_zero(T value) {
-  using std::fpclassify;
-  return fpclassify(value) != FP_ZERO;
-}
-
-template
-bool is_zero(T value) {
-  return !not_zero(value);
-}
diff --git a/csrc/quantization/machete/machete_mainloop.cuh b/csrc/quantization/machete/machete_mainloop.cuh
index f11105c041f70..4ea7e4e631291 100644
--- a/csrc/quantization/machete/machete_mainloop.cuh
+++ b/csrc/quantization/machete/machete_mainloop.cuh
@@ -25,7 +25,6 @@
 #include "cutlass/detail/collective.hpp"
 // clang-format on
 
-#include "math_utils.h"
 #include "cutlass_extensions/cute_utils.cuh"
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/csrc/quantization/machete/machete_prepack_launcher.cuh b/csrc/quantization/machete/machete_prepack_launcher.cuh
index 0531b02fef6a7..686dd68bd52bb 100644
--- a/csrc/quantization/machete/machete_prepack_launcher.cuh
+++ b/csrc/quantization/machete/machete_prepack_launcher.cuh
@@ -9,6 +9,7 @@ template
 torch::Tensor prepack_impl(torch::Tensor const B) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
   using ElementB = typename PrepackedLayoutB::ElementB;
+  using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK;
 
   auto device = B.device();
   auto stream = at::cuda::getCurrentCUDAStream(device.index());
@@ -21,6 +22,13 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
   //  match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L)
   auto Bt_packed = B.t();
 
+  TORCH_CHECK(
+      (B.size(0) * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0,
+      "B.shape[0] (in terms of unpacked elements) must be a multiple of ",
+      size<1>(PPBlockShape_NK{}));
+  TORCH_CHECK(B.size(1) % size<0>(PPBlockShape_NK{}) == 0,
+              "B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{}));
+
   using StrideB = cutlass::detail::TagToStrideB_t;
   auto const l_Bt_packed = make_cute_layout(Bt_packed, "B");
diff --git a/csrc/quantization/machete/machete_prepacked_layout.cuh b/csrc/quantization/machete/machete_prepacked_layout.cuh
index 7e03f25f904ef..b307341f6f16c 100644
--- a/csrc/quantization/machete/machete_prepacked_layout.cuh
+++ b/csrc/quantization/machete/machete_prepacked_layout.cuh
@@ -5,7 +5,8 @@
 #include
 
 // clang-format off
-// The cutlass include order
+// The cutlass include order matters (annoyingly)
+
 #include "cutlass/cutlass.h"
 #include "cute/tensor.hpp"
@@ -53,7 +54,7 @@ struct PrepackedLayoutBTemplate {
                       void>,
                   IlvBlkLayout_>;
 
-  // TODO (Lucas): compare the performance for other sizes
+  // TODO (LucasWilkinson): compare the performance for other sizes
   // Prepacked block shape, smallest layout atom for loading into registers
   // (can contain multiple wgmma instructions worth of data in one block)
   using PPBlockShape_NK = Shape<_128, _64>;
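The new TORCH_CHECKs in prepack_impl reject B tensors whose unpacked shape is not a multiple of the prepack block. A caller-side sketch of the same constraint, with loudly assumed values: the 128x64 block matches the Shape<_128, _64> default above, and 8 elements per int32 is only one possible pack factor; the real eles_per_storage is derived from the element and storage types inside prepack_impl.

# Hypothetical caller-side check mirroring the TORCH_CHECKs in prepack_impl.
PP_BLOCK_N, PP_BLOCK_K = 128, 64  # assumed PPBlockShape_NK (N, K)

def check_prepackable(b_q_packed, elems_per_storage=8):
    packed_k, n = b_q_packed.shape          # quantized B is (packed_K, N)
    k = packed_k * elems_per_storage        # unpacked K
    if k % PP_BLOCK_K != 0:
        raise ValueError(f"K={k} must be a multiple of {PP_BLOCK_K}")
    if n % PP_BLOCK_N != 0:
        raise ValueError(f"N={n} must be a multiple of {PP_BLOCK_N}")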
diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py
index 5a792b7cca4b1..dadf594409535 100644
--- a/tests/kernels/test_machete_gemm.py
+++ b/tests/kernels/test_machete_gemm.py
@@ -1,6 +1,6 @@
-"""Tests for the marlin kernel.
+"""Tests for the machete kernel.
 
-Run `pytest tests/kernels/marlin/test_machete_gemm.py`.
+Run `pytest tests/kernels/test_machete_gemm.py`.
 """
 
 import math
@@ -15,6 +15,10 @@
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
 
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]
+
 MNK_SHAPES = [
     (1, 128, 128),
     (1, 512, 1024),
@@ -23,7 +27,7 @@
     (26, 4096, 8192),
     (1, 4096, 4096),
     (257, 128, 4096),
-    (257, 4224, 4096),
+    (257, 4224, 4160),
     (257, 4096, 4096),
     (64, 4096, 4096),
 ]
@@ -75,6 +79,30 @@ def machete_quantize_and_pack(w: torch.Tensor,
     return w_ref, w_q_machete, w_s, w_zp
 
 
+def machete_gemm_test_helper(a: torch.Tensor, b: torch.Tensor,
+                             wtype: ScalarType, group_size: int,
+                             zero_points: bool):
+    w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
+        b, wtype, group_size, zero_points)
+
+    output_ref = torch.matmul(a, w_ref)
+
+    output = ops.machete_gemm(
+        a=a,
+        b_q=w_q_packed,
+        b_type=wtype,
+        b_scales=w_s,
+        b_zeros=maybe_convert_zeropoints(w_zp, w_s),
+        b_group_size=group_size,
+    )
+
+    # Relax atol as our reduction dim becomes larger (more rounding error)
+    # Relax atol when we have zeropoints since the way machete applies
+    # zeropoints (after scales) causes noise around 0
+    atol = 1 if zero_points else min(5e-2 * math.sqrt(a.shape[1]), 1)
+    torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol)
+
+
 @pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                     reason="Machete is not supported on this GPU type.")
 @pytest.mark.parametrize("shape",
@@ -86,18 +114,21 @@ def machete_quantize_and_pack(w: torch.Tensor,
 def test_machete_all_schedules(shape, atype: torch.dtype,
                                wtype_zeropoints: Tuple[ScalarType, bool],
                                group_size: Optional[int]):
-    size_m, size_k, size_n = shape
+    m, n, k = shape
     wtype, zero_points = wtype_zeropoints
 
-    print(f"MNK = {size_m} {size_n} {size_k}")
+    if group_size is not None and k % group_size != 0:
+        return
+
+    print(f"MNK = {m} {n} {k}")
 
     # Normalize group_size
     if group_size is None:
-        group_size = size_k
-    assert group_size <= size_k
+        group_size = k
+    assert group_size <= k
 
-    a = rand_data((size_m, size_k), atype)
-    w = rand_data((size_k, size_n), atype)
+    a = rand_data((m, k), atype)
+    w = rand_data((k, n), atype)
 
     w_ref, w_q_machete, w_s, w_zp = machete_quantize_and_pack(
         w, wtype, group_size, zero_points)
@@ -118,7 +149,7 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
         # Relax atol as our reduction dim becomes larger (more rounding error)
         # Relax atol when we have zeropoints since the way machete applies
         # zeropoints (after scales) causes noise around 0
-        atol = 1 if zero_points else min(5e-2 * math.sqrt(size_k), 1)
+        atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
         torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol),\
             f"Schedule failed {schedule}"
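For reference, the relaxed tolerance used above saturates quickly at its cap. A small sketch evaluating the same formula for a few K values from MNK_SHAPES (zero_points=False case):

# Tolerance formula from the tests; the cap of 1 dominates once K > ~400.
import math
for k in (128, 1024, 4096, 8192):
    print(k, min(5e-2 * math.sqrt(k), 1))
# 128  -> ~0.57
# 1024 -> 1 (capped)
# 4096 -> 1 (capped)
# 8192 -> 1 (capped)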
@@ -134,26 +165,88 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
 def test_machete_heuristic(shape, atype: torch.dtype,
                            wtype_zeropoints: Tuple[ScalarType, bool],
                            group_size: Optional[int]):
-    size_m, size_k, size_n = shape
+    m, n, k = shape
     wtype, zero_points = wtype_zeropoints
 
-    print(f"MNK = {size_m} {size_n} {size_k}")
+    if group_size is not None and k % group_size != 0:
+        return
 
     # Normalize group_size
     if group_size is None:
-        group_size = size_k
-    assert group_size <= size_k
+        group_size = k
+    assert group_size <= k
 
-    a = rand_data((size_m, size_k), atype)
-    b_weight = rand_data((size_k, size_n), atype)
+    a = rand_data((m, k), atype)
+    b = rand_data((k, n), atype)
 
-    w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
-        b_weight, wtype, group_size, zero_points)
+    machete_gemm_test_helper(a, b, wtype, group_size, zero_points)
 
-    output_ref = torch.matmul(a, w_ref)
 
-    output = ops.machete_gemm(
-        a,
+# Test working on other devices
+@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
+                    reason="Machete is not supported on this GPU type.")
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_machete_devices(device: str):
+    m, n, k = 512, 4096, 4096
+    wtype = scalar_types.uint4b8
+    group_size = 128
+    zero_points = False
+
+    print(f"MNK = {m} {n} {k}, device = {device}")
+
+    a = rand_data((m, k), torch.float16).to(device)
+    b = rand_data((k, n), torch.float16).to(device)
+
+    machete_gemm_test_helper(a, b, wtype, group_size, zero_points)
+
+
+# Test working with a subset of A and B
+@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
+                    reason="Machete is not supported on this GPU type.")
+def test_machete_subset():
+    big_m, big_n, big_k = 1024, 1024, 1024
+    m, n, k = 512, 512, 512
+    wtype = scalar_types.uint4b8
+    group_size = 128
+    zero_points = False
+
+    whole_a = rand_data((big_m, big_k), torch.float16)
+    whole_b = rand_data((big_k, big_n), torch.float16)
+
+    a = whole_a[0:m, 0:k]
+    b = whole_b[0:k, 0:n]
+
+    machete_gemm_test_helper(a, b, wtype, group_size, zero_points)
+
+
+# Test to make sure cuda graphs work
+class MacheteLayer(torch.nn.Module):
+
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.kwargs = kwargs
+
+    def forward(self, a):
+        return ops.machete_gemm(**self.kwargs)
+
+
+@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
+                    reason="Machete is not supported on this GPU type.")
+def test_machete_cuda_graph():
+    m, n, k = 512, 4096, 4096
+
+    a = rand_data((m, k), torch.float16)
+    b = rand_data((k, n), torch.float16)
+    wtype = scalar_types.uint4b8
+    group_size = 128
+    zero_points = False
+
+    w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
+        b, wtype, group_size, zero_points)
+
+    # Construct a trivial model with a single layer that calls a machete kernel
+    model = MacheteLayer(
+        a=a,
         b_q=w_q_packed,
         b_type=wtype,
         b_scales=w_s,
@@ -161,8 +254,19 @@ def test_machete_heuristic(shape, atype: torch.dtype,
         b_group_size=group_size,
     )
 
+    output_ref = torch.matmul(a, w_ref)
+
+    # Run the model with a cuda graph
+    stream = torch.cuda.Stream()
+    with torch.cuda.stream(stream):
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            output = model(a)
+    output.zero_()
+    g.replay()
+
     # Relax atol as our reduction dim becomes larger (more rounding error)
     # Relax atol when we have zeropoints since the way machete applies
     # zeropoints (after scales) causes noise around 0
-    atol = 1 if zero_points else min(5e-2 * math.sqrt(size_k), 1)
+    atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
     torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol)
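The capture/replay check in test_machete_cuda_graph (capture once, zero the output, replay, compare) generalizes to any graph-capturable op. A minimal standalone sketch, using a plain matmul as a stand-in for ops.machete_gemm; the side-stream warmup before capture is the usual PyTorch recommendation and is an addition here, not something the test above does:

import torch

def check_graph_capturable(fn, *args):
    # Warm up on a side stream (recommended before capture).
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        fn(*args)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out = fn(*args)  # static output buffer, rewritten on every replay

    expected = fn(*args)  # eager reference with the same static inputs
    out.zero_()           # clobber the buffer; only a replay can restore it
    g.replay()
    torch.testing.assert_close(out, expected)

if torch.cuda.is_available():
    a = torch.randn(512, 512, device="cuda", dtype=torch.float16)
    b = torch.randn(512, 512, device="cuda", dtype=torch.float16)
    check_graph_capturable(torch.matmul, a, b)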