Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasWilkinson committed Aug 15, 2024
1 parent a152da3 commit e92b26e
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 77 deletions.
24 changes: 16 additions & 8 deletions csrc/cutlass_extensions/vllm_numeric_conversion.cuh
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
// Based off of:
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h

#pragma once

#include "cutlass/numeric_conversion.h"
#include "cutlass_extensions/vllm_custom_types.cuh"
#include "cutlass_extensions/cute_utils.cuh"

// this file extends:
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
// with vllm specific type conversions, namely: vllm_uint4b8_t, vllm_uint8b128_t
// as well as adds interleaved numeric array converters for specific types.
// (interleaved numeric array converters can be more efficient for subbyte
// types)

namespace cutlass {

// InterleavedNumericArrayConverter is like NumericArrayConverter but also
// deinterleaves converted elements based on IlvBlkLayout, interleaving can
// make subbyte converts more efficient by allowing for efficient extraction
// of subbyte elements from a 32bit register.
template <typename IlvBlkLayout, typename T, typename S, int N,
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
class Enable = void>
Expand Down Expand Up @@ -48,7 +56,7 @@ struct InterleavedNumericArrayConverter<
result_type operator()(source_type const& s) const { return convert(s); }
};

// TODO (Lucas): Implement
// TODO (LucasWilkinson): Implement
// for Array<cutlass::float8_e4m3fn, N> <= Array<vllm_uint4b8_t, N>

// ....
Expand All @@ -71,10 +79,10 @@ struct ArrayConverterPacked32Bit {
static constexpr auto src_elems_per_32bit_reg =
32 / cutlass::sizeof_bits_v<S>;

// Maybe not Valid,. ScalarConverter will not actually work unless
// NumericConverter<T, S, Round> is implemented
// but it won't be used since we assert N % 2 == 0, just here for compliance
// with VectorizedConverter
// Maybe not Valid. ScalarConverter will not actually work unless
// NumericConverter<T, S, Round> is implemented. However it won't be used
// anyways since we assert N % 2 == 0, just here for compliance with
// VectorizedConverter.
using ScalarConverter = NumericConverter<T, S>;

template <typename PackedSrc>
Expand Down
44 changes: 0 additions & 44 deletions csrc/math_utils.h

This file was deleted.

1 change: 0 additions & 1 deletion csrc/quantization/machete/machete_mainloop.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include "cutlass/detail/collective.hpp"
// clang-format on

#include "math_utils.h"
#include "cutlass_extensions/cute_utils.cuh"

/////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
8 changes: 8 additions & 0 deletions csrc/quantization/machete/machete_prepack_launcher.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ template <typename PrepackedLayoutB>
torch::Tensor prepack_impl(torch::Tensor const B) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
using ElementB = typename PrepackedLayoutB::ElementB;
using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK;

auto device = B.device();
auto stream = at::cuda::getCurrentCUDAStream(device.index());
Expand All @@ -21,6 +22,13 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
// match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L)
auto Bt_packed = B.t();

TORCH_CHECK(
(B.size(0) * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0,
"B.shape[0] (in terms of unpacked elements) must be a multiple of ",
size<1>(PPBlockShape_NK{}));
TORCH_CHECK(B.size(1) % size<0>(PPBlockShape_NK{}) == 0,
"B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{}));

using StrideB = cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>;
auto const l_Bt_packed = make_cute_layout<StrideB>(Bt_packed, "B");

Expand Down
5 changes: 3 additions & 2 deletions csrc/quantization/machete/machete_prepacked_layout.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
#include <torch/all.h>

// clang-format off
// The cutlass include order
// The cutlass include order matters (annoyingly)

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
Expand Down Expand Up @@ -53,7 +54,7 @@ struct PrepackedLayoutBTemplate {
void>,
IlvBlkLayout_>;

// TODO (Lucas): compare the performance for other sizes
// TODO (LucasWilkinson): compare the performance for other sizes
// Prepacked block shape, smallest layout atom for loading into registers
// (can contain multiple wgmma instructions worth of data in one block)
using PPBlockShape_NK = Shape<_128, _64>;
Expand Down
148 changes: 126 additions & 22 deletions tests/kernels/test_machete_gemm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tests for the marlin kernel.
"""Tests for the machete kernel.
Run `pytest tests/kernels/marlin/test_machete_gemm.py`.
Run `pytest tests/kernels/test_machete_gemm.py`.
"""

import math
Expand All @@ -15,6 +15,10 @@
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

MNK_SHAPES = [
(1, 128, 128),
(1, 512, 1024),
Expand All @@ -23,7 +27,7 @@
(26, 4096, 8192),
(1, 4096, 4096),
(257, 128, 4096),
(257, 4224, 4096),
(257, 4224, 4160),
(257, 4096, 4096),
(64, 4096, 4096),
]
Expand Down Expand Up @@ -75,6 +79,30 @@ def machete_quantize_and_pack(w: torch.Tensor,
return w_ref, w_q_machete, w_s, w_zp


def machete_gemm_test_helper(a: torch.Tensor, b: torch.Tensor,
wtype: ScalarType, group_size: int,
zero_points: bool):
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
b, wtype, group_size, zero_points)

output_ref = torch.matmul(a, w_ref)

output = ops.machete_gemm(
a=a,
b_q=w_q_packed,
b_type=wtype,
b_scales=w_s,
b_zeros=maybe_convert_zeropoints(w_zp, w_s),
b_group_size=group_size,
)

# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
# zeropoints (after scales) causes noise around 0
atol = 1 if zero_points else min(5e-2 * math.sqrt(a.shape[1]), 1)
torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol)


@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
Expand All @@ -86,18 +114,21 @@ def machete_quantize_and_pack(w: torch.Tensor,
def test_machete_all_schedules(shape, atype: torch.dtype,
wtype_zeropoints: Tuple[ScalarType, bool],
group_size: Optional[int]):
size_m, size_k, size_n = shape
m, n, k = shape
wtype, zero_points = wtype_zeropoints

print(f"MNK = {size_m} {size_n} {size_k}")
if group_size is not None and k % group_size != 0:
return

print(f"MNK = {m} {n} {k}")

# Normalize group_size
if group_size is None:
group_size = size_k
assert group_size <= size_k
group_size = k
assert group_size <= k

a = rand_data((size_m, size_k), atype)
w = rand_data((size_k, size_n), atype)
a = rand_data((m, k), atype)
w = rand_data((k, n), atype)

w_ref, w_q_machete, w_s, w_zp = machete_quantize_and_pack(
w, wtype, group_size, zero_points)
Expand All @@ -118,7 +149,7 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
# zeropoints (after scales) causes noise around 0
atol = 1 if zero_points else min(5e-2 * math.sqrt(size_k), 1)
atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol),\
f"Schedule failed {schedule}"

Expand All @@ -134,35 +165,108 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
def test_machete_heuristic(shape, atype: torch.dtype,
wtype_zeropoints: Tuple[ScalarType, bool],
group_size: Optional[int]):
size_m, size_k, size_n = shape
m, n, k = shape
wtype, zero_points = wtype_zeropoints

print(f"MNK = {size_m} {size_n} {size_k}")
if group_size is not None and k % group_size != 0:
return

# Normalize group_size
if group_size is None:
group_size = size_k
assert group_size <= size_k
group_size = k
assert group_size <= k

a = rand_data((size_m, size_k), atype)
b_weight = rand_data((size_k, size_n), atype)
a = rand_data((m, k), atype)
b = rand_data((k, n), atype)

w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
b_weight, wtype, group_size, zero_points)
machete_gemm_test_helper(a, b, wtype, group_size, zero_points)

output_ref = torch.matmul(a, w_ref)

output = ops.machete_gemm(
a,
# Test working on other devices
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_machete_devices(device: str):
m, n, k = 512, 4096, 4096
wtype = scalar_types.uint4b8
group_size = 128
zero_points = False

print(f"MNK = {m} {n} {k}, device = {device}")

a = rand_data((m, k), torch.float16).to(device)
b = rand_data((k, n), torch.float16).to(device)

machete_gemm_test_helper(a, b, wtype, group_size, zero_points)


# Test working with a subset of A and B
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
def test_machete_subset():
big_m, big_n, big_k = 1024, 1024, 1024
m, n, k = 512, 512, 512
wtype = scalar_types.uint4b8
group_size = 128
zero_points = False

whole_a = rand_data((big_m, big_k), torch.float16)
whole_b = rand_data((big_k, big_n), torch.float16)

a = whole_a[0:m, 0:k]
b = whole_b[0:k, 0:n]

machete_gemm_test_helper(a, b, wtype, group_size, zero_points)


# Test to make sure cuda graphs work
class MacheteLayer(torch.nn.Module):

def __init__(self, **kwargs):
super().__init__()
self.kwargs = kwargs

def forward(self, a):
return ops.machete_gemm(**self.kwargs)


@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
def test_machete_cuda_graph():
m, n, k = 512, 4096, 4096

a = rand_data((m, k), torch.float16)
b = rand_data((k, n), torch.float16)
wtype = scalar_types.uint4b8
group_size = 128
zero_points = False

w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
b, wtype, group_size, zero_points)

# Construct a trivial model with a single layer that calls a machete kernel
model = MacheteLayer(
a=a,
b_q=w_q_packed,
b_type=wtype,
b_scales=w_s,
b_zeros=maybe_convert_zeropoints(w_zp, w_s),
b_group_size=group_size,
)

output_ref = torch.matmul(a, w_ref)

# Run the model with a cuda graph
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
output = model(a)
output.zero_()
g.replay()

# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
# zeropoints (after scales) causes noise around 0
atol = 1 if zero_points else min(5e-2 * math.sqrt(size_k), 1)
atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol)

0 comments on commit e92b26e

Please sign in to comment.