Commit 8172f4e: review comments
LucasWilkinson committed Aug 9, 2024
1 parent 50508ea
Showing 12 changed files with 53 additions and 70 deletions.
1 change: 0 additions & 1 deletion CMakeLists.txt
@@ -242,7 +242,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
)

if (NOT machete_generation_result EQUAL 0)
file(READ ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log log)
message(FATAL_ERROR "Machete generation failed."
" Result: \"${machete_generation_result}\""
"\nCheck the log for details: "
2 changes: 1 addition & 1 deletion csrc/cuda_utils.h
@@ -8,7 +8,7 @@
#define HOST_DEVICE_INLINE inline
#define DEVICE_INLINE inline
#define HOST_INLINE inline
#endif // CUTE_HOST_DEVICE, CUTE_DEVICE
#endif

int64_t get_device_attribute(int64_t attribute, int64_t device_id);

10 changes: 10 additions & 0 deletions csrc/cutlass_extensions/cute_utils.cuh
@@ -4,6 +4,16 @@

////////////////////////////////////////////////////////////////////
// make_cute_stride
// - instantiates a stride object that's correctly populated based
// on the shape of the tensor and the stride type passed in,
// for example:
// - if s = Stride<int, _1> and shape = {M, N, L} then the stride will be
// constructed as {N, 1}, i.e. Row Major
// - if s = Stride<_1, int> and shape = {M, N, L} then the stride will be
// constructed as {1, M}, i.e. Column Major
// - if s = Stride<int, _1, int64_t> and shape = {M, N, L} then the stride
// will be constructed as {N, 1, M * N}, i.e. Row Major Batched
// - etc.
////////////////////////////////////////////////////////////////////

//
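As an illustration of the rule documented above, here is a minimal, self-contained C++ sketch with no CUTE dependency; the helper name, the bool template flag, and the dimension values are invented for the example and are not the repository's make_cute_stride API.

// Sketch only: mirrors the documented rule that the static-1 position in the
// stride type picks the contiguous dimension, and the batch stride is M * N.
#include <array>
#include <cstdint>
#include <iostream>

template <bool RowMajor>
std::array<int64_t, 3> make_batched_stride(int64_t M, int64_t N, int64_t L) {
  (void)L;  // L only sizes the batch dimension; it does not affect the strides
  if constexpr (RowMajor)
    return {N, 1, M * N};   // shape {M, N, L}, row major
  else
    return {1, M, M * N};   // shape {M, N, L}, column major
}

int main() {
  auto rm = make_batched_stride<true>(128, 64, 4);   // {64, 1, 8192}
  auto cm = make_batched_stride<false>(128, 64, 4);  // {1, 128, 8192}
  std::cout << rm[0] << ", " << rm[1] << ", " << rm[2] << "\n";
  std::cout << cm[0] << ", " << cm[1] << ", " << cm[2] << "\n";
}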
24 changes: 6 additions & 18 deletions csrc/cutlass_extensions/torch_utils.hpp
@@ -19,24 +19,6 @@ static inline bool is_column_major(torch::Tensor const tensor) {
return tensor.stride(0) == 1 && tensor.stride(1) == tensor.size(0);
}

template <typename T, typename Layout = RowMajor>
T* maybe_data_ptr(c10::optional<torch::Tensor const> maybe_tensor,
char const* name) {
if constexpr (std::is_same_v<Layout, RowMajor>) {
TORCH_CHECK(!maybe_tensor || is_row_major(*maybe_tensor), "Expected ", name,
" to be RowMajor");
} else if constexpr (std::is_same_v<Layout, ColumnMajor>) {
TORCH_CHECK(!maybe_tensor || is_column_major(*maybe_tensor), "Expected ",
name, " to be ColumnMajor");
} else {
TORCH_CHECK(false, "Unknown Layout");
}

return (maybe_tensor == at::nullopt)
? nullptr
: reinterpret_cast<T*>(maybe_tensor->data_ptr());
}

template <typename T, typename Layout = RowMajor>
T* data_ptr(torch::Tensor const tensor, char const* name) {
if constexpr (std::is_same_v<Layout, RowMajor>) {
@@ -51,6 +33,12 @@ T* data_ptr(torch::Tensor const tensor, char const* name) {
return reinterpret_cast<T*>(tensor.data_ptr());
}

template <typename T, typename Layout = RowMajor>
T* maybe_data_ptr(c10::optional<torch::Tensor const> maybe_tensor,
char const* name) {
return (maybe_tensor) ? data_ptr<T, Layout>(*maybe_tensor, name) : nullptr;
}

//
// Torch Type to Cutlass Type (equivalent_cutlass_type)
//
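The reshuffled maybe_data_ptr above now simply forwards to data_ptr when the optional tensor is present, so the layout check lives in one place. A standalone sketch of that delegation pattern, using std::optional and a toy Tensor type rather than the torch/c10 ones:

// Sketch only: toy types standing in for torch::Tensor and c10::optional.
#include <cassert>
#include <iostream>
#include <optional>
#include <vector>

struct Tensor { std::vector<float> storage; bool row_major = true; };

template <typename T>
T* data_ptr(Tensor& tensor, char const* name) {
  // All validation (here, the layout check) happens in data_ptr.
  assert(tensor.row_major && "expected row-major tensor");
  (void)name;
  return tensor.storage.data();
}

template <typename T>
T* maybe_data_ptr(std::optional<Tensor>& maybe_tensor, char const* name) {
  // The optional-aware wrapper only decides between "forward" and "nullptr".
  return maybe_tensor ? data_ptr<T>(*maybe_tensor, name) : nullptr;
}

int main() {
  std::optional<Tensor> scales = Tensor{{1.f, 2.f, 3.f}};
  std::optional<Tensor> zeros;  // not provided
  std::cout << (maybe_data_ptr<float>(scales, "scales") != nullptr) << "\n";  // 1
  std::cout << (maybe_data_ptr<float>(zeros, "zeros") != nullptr) << "\n";    // 0
}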
9 changes: 4 additions & 5 deletions csrc/cutlass_extensions/vllm_numeric_conversion.cuh
@@ -1,4 +1,4 @@
// Based off of:
// Based off of:
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h

#pragma once
@@ -264,7 +264,6 @@ struct NumericArrayConverter<float, vllm_uint8b128_t, N, Round> {
using source_type_packed_4 = Array<vllm_uint8b128_t, 4>;
using source_type_packed_2 = Array<vllm_uint8b128_t, 2>;


// Not Valid, not supported, only here to satisfy the interface and to avoid
// a compile error. ScalarConverter will not actually work until
// NumericConverter<cutlass::float, vllm_uint8b128_t, Round> is implemented
@@ -350,7 +349,7 @@ struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint4b8_t, N, Round> {

// Not Valid, not supported, only here to satisfy the interface and to avoid
// a compile error. ScalarConverter will not actually work until
// NumericConverter<cutlass::bfloat16_t, vllm_uint4b8_t, Round> is
// NumericConverter<cutlass::bfloat16_t, vllm_uint4b8_t, Round> is
// implemented
using ScalarConverter =
NumericConverter<cutlass::bfloat16_t, vllm_uint4b8_t, Round>;
@@ -486,8 +485,8 @@ struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint8b128_t, N, Round> {

// Not Valid, not supported, only here to satisfy the interface and to avoid
// a compile error. ScalarConverter will not actually work until
// NumericConverter<cutlass::bfloat16_t, vllm_uint8b128_t, Round> is
// implemented
// NumericConverter<cutlass::bfloat16_t, vllm_uint8b128_t, Round> is
// implemented
using ScalarConverter =
NumericConverter<cutlass::bfloat16_t, vllm_uint8b128_t, Round>;

2 changes: 0 additions & 2 deletions csrc/ops.h
@@ -85,8 +85,6 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,

namespace machete {

std::vector<vllm::ScalarTypeTorchPtr> supported_types();

std::vector<std::string> supported_schedules(
vllm::ScalarTypeTorchPtr const& btype);

10 changes: 5 additions & 5 deletions csrc/quantization/machete/generate.py
@@ -27,7 +27,7 @@
#include "../machete_mm_launcher.cuh"
namespace machete {
using KernelDispatcher_ = KernelDispatcher<
using GemmDispatcher_ = GemmDispatcher<
{{DataTypeTag[type_config.element_a]}}, // ElementA
{{DataTypeTag[type_config.element_b]}}, // ElementB
{{DataTypeTag[type_config.element_d]}}, // ElementD
@@ -36,10 +36,10 @@
{{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints
{% for s in schedules %}extern torch::Tensor
impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PytorchArguments args);
impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PyTorchArguments args);
{% endfor %}
template <>
torch::Tensor KernelDispatcher_::dispatch(PytorchArguments args) {
torch::Tensor GemmDispatcher_::dispatch(PyTorchArguments args) {
[[maybe_unused]] auto M = args.A.size(0);
[[maybe_unused]] auto N = args.B.size(1);
[[maybe_unused]] auto K = args.A.size(1);
@@ -62,7 +62,7 @@
}
template <>
std::vector<std::string> KernelDispatcher_::supported_schedules() {
std::vector<std::string> GemmDispatcher_::supported_schedules() {
return {
{% for s in schedules -%}
"{{ gen_sch_name(s) }}"{{ ",
@@ -103,7 +103,7 @@
};
torch::Tensor
impl_{{type_name}}_sch_{{schedule_name}}(PytorchArguments args) {
impl_{{type_name}}_sch_{{schedule_name}}(PyTorchArguments args) {
bool with_C = args.C.has_value(), with_scales = args.scales.has_value(),
with_zeropoints = args.zeros.has_value();
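The template above now emits one GemmDispatcher specialization per type configuration, whose dispatch() routes to an impl_*_sch_* function per schedule. A compilable sketch of that pattern follows; the schedule names, the args struct, and the fallback behaviour are placeholders, not what the generator actually produces.

// Sketch only: impl_sch_a/impl_sch_b stand in for the generated per-schedule kernels.
#include <iostream>
#include <map>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

struct PyTorchArgsSketch { std::optional<std::string> schedule; };

int impl_sch_a(PyTorchArgsSketch const&) { return 1; }
int impl_sch_b(PyTorchArgsSketch const&) { return 2; }

struct GemmDispatcherSketch {
  static int dispatch(PyTorchArgsSketch args) {
    static const std::map<std::string, int (*)(PyTorchArgsSketch const&)>
        impls = {{"sch_a", impl_sch_a}, {"sch_b", impl_sch_b}};
    // The real generated code can also pick a schedule from the problem shape
    // when none is requested; here we just fall back to "sch_a".
    auto it = impls.find(args.schedule.value_or("sch_a"));
    if (it == impls.end()) throw std::runtime_error("unknown schedule");
    return it->second(args);
  }
  static std::vector<std::string> supported_schedules() {
    return {"sch_a", "sch_b"};
  }
};

int main() {
  std::cout << GemmDispatcherSketch::dispatch({std::string("sch_b")}) << "\n";  // 2
  for (auto const& s : GemmDispatcherSketch::supported_schedules())
    std::cout << s << "\n";
}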
42 changes: 21 additions & 21 deletions csrc/quantization/machete/machete_mm_launcher.cuh
@@ -8,7 +8,7 @@

namespace machete {

struct PytorchArguments {
struct PyTorchArguments {
torch::Tensor const A;
torch::Tensor const B;
c10::optional<torch::Tensor> const& scales;
@@ -20,26 +20,26 @@ struct PytorchArguments {
c10::optional<std::string> schedule;
};

template <typename KernelSpeacialization>
torch::Tensor run_impl(PytorchArguments args) {
template <typename KernelSpecialization>
torch::Tensor run_impl(PyTorchArguments args) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A));

auto device = args.A.device();
auto stream = at::cuda::getCurrentCUDAStream(device.index());

using ElementA = typename KernelSpeacialization::ElementA;
using ElementB = typename KernelSpeacialization::ElementB;
using ElementC = typename KernelSpeacialization::ElementC;
using ElementD = typename KernelSpeacialization::ElementD;
using ElementScale = typename KernelSpeacialization::ElementScale;
using ElementZero = typename KernelSpeacialization::ElementZero;
using ElementA = typename KernelSpecialization::ElementA;
using ElementB = typename KernelSpecialization::ElementB;
using ElementC = typename KernelSpecialization::ElementC;
using ElementD = typename KernelSpecialization::ElementD;
using ElementScale = typename KernelSpecialization::ElementScale;
using ElementZero = typename KernelSpecialization::ElementZero;

using LayoutA = typename KernelSpeacialization::LayoutA;
using LayoutB = typename KernelSpeacialization::LayoutB;
using LayoutC = typename KernelSpeacialization::LayoutC;
using LayoutD = typename KernelSpeacialization::LayoutD;
using LayoutScale = typename KernelSpeacialization::LayoutScale;
using LayoutZero = typename KernelSpeacialization::LayoutScale;
using LayoutA = typename KernelSpecialization::LayoutA;
using LayoutB = typename KernelSpecialization::LayoutB;
using LayoutC = typename KernelSpecialization::LayoutC;
using LayoutD = typename KernelSpecialization::LayoutD;
using LayoutScale = typename KernelSpecialization::LayoutScale;
using LayoutZero = typename KernelSpecialization::LayoutScale;

int M = args.A.size(0);
int N = args.B.size(1);
@@ -60,28 +60,28 @@ torch::Tensor run_impl(PytorchArguments args) {
auto zeros_ptr =
maybe_data_ptr<ElementZero const, LayoutZero>(args.zeros, "zeros");

auto arguments = KernelSpeacialization::create_arguments(
auto arguments = KernelSpecialization::create_arguments(
stream, M, N, K, A_ptr, B_ptr, C_ptr, D_ptr, scales_ptr, zeros_ptr,
args.alpha.value_or(1), args.beta.value_or(0),
args.group_size.value_or(K));

TORCH_CHECK(KernelSpeacialization::can_implement(arguments),
TORCH_CHECK(KernelSpecialization::can_implement(arguments),
"Machete kernel cannot be run with these arguments");

size_t workspace_size = KernelSpeacialization::get_workspace_size(arguments);
size_t workspace_size = KernelSpecialization::get_workspace_size(arguments);
torch::Tensor workspace = torch::empty(
workspace_size, torch::TensorOptions().dtype(torch::kU8).device(device));

KernelSpeacialization::run(arguments, workspace.mutable_data_ptr(), stream);
KernelSpecialization::run(arguments, workspace.mutable_data_ptr(), stream);

return D;
};

template <typename ElementA, typename ElementB, typename ElementD = ElementA,
typename AccumulatorT = float, typename ScaleT = ElementA,
typename ZeroT = ElementA>
struct KernelDispatcher {
static torch::Tensor dispatch(PytorchArguments args);
struct GemmDispatcher {
static torch::Tensor dispatch(PyTorchArguments args);
static std::vector<std::string> supported_schedules();
};

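run_impl above follows the usual CUTLASS-style query-allocate-run sequence: build arguments, check can_implement, size and allocate a workspace, then run. A device-free sketch of that control flow; FakeKernel is invented and only mimics the four calls visible in the diff, and there is no CUDA stream or torch tensor here.

// Sketch only: FakeKernel stands in for a KernelSpecialization.
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

struct FakeKernel {
  struct Arguments { int M, N, K; };
  static bool can_implement(Arguments const& a) { return a.K % 8 == 0; }
  static size_t get_workspace_size(Arguments const& a) {
    return static_cast<size_t>(a.M) * 4;
  }
  static void run(Arguments const&, void* /*workspace*/) {}
};

template <typename Kernel>
void run_impl_sketch(int M, int N, int K) {
  typename Kernel::Arguments arguments{M, N, K};
  if (!Kernel::can_implement(arguments))
    throw std::runtime_error("kernel cannot be run with these arguments");
  // The real code allocates the workspace as a torch::Tensor on the device.
  std::vector<unsigned char> workspace(Kernel::get_workspace_size(arguments));
  Kernel::run(arguments, workspace.data());
}

int main() {
  run_impl_sketch<FakeKernel>(128, 64, 4096);
  std::cout << "ok\n";
}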
11 changes: 3 additions & 8 deletions csrc/quantization/machete/machete_pytorch.cu
@@ -36,14 +36,9 @@ static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
// Interface
//

std::vector<ScalarTypeTorchPtr> supported_types() {
return {c10::make_intrusive<ScalarTypeTorch>(vllm::kU4),
c10::make_intrusive<ScalarTypeTorch>(vllm::kS4)};
}

std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
return scalar_type_dispatch(*btype, [&](auto BType) {
return KernelDispatcher<half_t, decltype(BType)>::supported_schedules();
return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
});
}

@@ -55,7 +50,7 @@ torch::Tensor gemm(torch::Tensor const A, torch::Tensor const B,
c10::optional<torch::Tensor> const& C,
c10::optional<double> alpha, c10::optional<double> beta,
c10::optional<std::string> schedule) {
auto args = PytorchArguments{.A = A,
auto args = PyTorchArguments{.A = A,
.B = B,
.scales = scales,
.zeros = zeros,
@@ -69,7 +64,7 @@ torch::Tensor gemm(torch::Tensor const A, torch::Tensor const B,
return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
A.scalar_type(), "machete_gemm", [&] {
using ComputeType = equivalent_cutlass_type_t<scalar_t>;
return KernelDispatcher<ComputeType, decltype(BType)>::dispatch(args);
return GemmDispatcher<ComputeType, decltype(BType)>::dispatch(args);
});
});
}
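gemm above turns two runtime type values into compile-time template parameters: the b-type via scalar_type_dispatch and the compute type via AT_DISPATCH_SUPPORTED_COMPUTE_TYPES. A minimal sketch of that runtime-enum-to-static-type dispatch; the enum and tag types are invented, the real code dispatches on vllm::ScalarType and torch scalar types.

// Sketch only: BTypeSketch/U4Tag/S4Tag are illustrative placeholders.
#include <iostream>
#include <stdexcept>
#include <type_traits>

enum class BTypeSketch { kU4, kS4 };
struct U4Tag {};
struct S4Tag {};

template <typename Fn>
auto scalar_type_dispatch_sketch(BTypeSketch type, Fn fn) {
  switch (type) {
    case BTypeSketch::kU4: return fn(U4Tag{});
    case BTypeSketch::kS4: return fn(S4Tag{});
  }
  throw std::runtime_error("unsupported type");
}

int main() {
  scalar_type_dispatch_sketch(BTypeSketch::kU4, [](auto btype) {
    using BType = decltype(btype);  // a compile-time type inside the lambda
    std::cout << (std::is_same_v<BType, U4Tag> ? "u4" : "s4") << "\n";
  });
}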
2 changes: 0 additions & 2 deletions csrc/torch_bindings.cpp
@@ -137,8 +137,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("machete_supported_schedules", &machete::supported_schedules);
ops.impl("machete_supported_schedules", torch::kCPU,
&machete::supported_schedules);
ops.def("machete_supported_types", &machete::supported_types);
ops.impl("machete_supported_types", torch::kCPU, &machete::supported_types);
ops.def("machete_gemm", &machete::gemm);
ops.impl("machete_gemm", torch::kCUDA, &machete::gemm);
ops.def("machete_prepack_B", &machete::prepack_B);
4 changes: 0 additions & 4 deletions vllm/_custom_ops.py
@@ -334,10 +334,6 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,


# machete
def machete_supported_types() -> List[ScalarType]:
return torch.ops._C.machete_supported_types()


def machete_supported_schedules(b_type: ScalarType) -> List[str]:
return torch.ops._C.machete_supported_schedules(b_type)

6 changes: 3 additions & 3 deletions vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -127,9 +127,9 @@ def quantize_weights(w: torch.Tensor,
w_q = torch.clamp(w_q, min_q_val, max_q_val)

# Compute ref (dequantized)
# For some kernels (namely Machete) the scales are applied after the scales
# are applied, for this case computing the reference in similar way allows
# us to use tighter error tolerances in our unit tests.
# For some kernels (namely Machete) the zero-points are applied after the
# scales are applied, for this case computing the reference in similar way
# allows us to use tighter error tolerances in our unit tests.
if ref_zero_points_after_scales and zero_points:
w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
else:
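The corrected comment above says the reference dequantization should apply zero-points after scales, matching Machete's order, so that unit-test tolerances can stay tight. A toy C++ sketch of why the order matters once every intermediate is rounded to a low-precision type; the coarse 1/8-grid quantizer and the numbers are invented stand-ins for fp16 rounding, not the kernel's arithmetic.

// Sketch only: simulates low-precision storage with a coarse quantizer.
#include <cmath>
#include <iostream>

double round_lowp(double x) { return std::round(x * 8.0) / 8.0; }  // fake low precision

int main() {
  double w_q = 5, z = 2, s = 0.09;
  // zero-point applied before the scale, rounded once at the end:
  double a = round_lowp((w_q - z) * s);
  // zero-point applied after the scale, each product rounded (Machete's order):
  double b = round_lowp(round_lowp(w_q * s) - round_lowp(z * s));
  std::cout << a << " vs " << b << "\n";  // 0.25 vs 0.375: the two references differ
}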
