Skip to content

Commit

Permalink
review comments
Browse files: browse the repository at this point in the history
Signed-off-by: Lucas Wilkinson <[email protected]>
  • Loading branch information
LucasWilkinson committed Nov 6, 2024
1 parent 17bebb1 commit 565770c
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 29 deletions.
11 changes: 7 additions & 4 deletions csrc/cutlass_extensions/vllm_numeric_conversion.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@ struct InterleavedNumericArrayConverter {
static result_type convert(source_type const& source) {
if (cute::elect_one_sync()) {
if constexpr (std::is_same_v<IlvBlkLayout, void>) {
printf(" %s <= %s (N = %d, IlvBlkLayout = void)\n", nameof_v<T>,
nameof_v<S>, N);
printf(
"Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n",
nameof_v<T>, nameof_v<S>, N);
} else {
printf(" %s <= %s (N = %d, size(IlvBlkLayout{}) = %d)\n", nameof_v<T>,
nameof_v<S>, N, size(IlvBlkLayout{}));
printf(
"Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not "
"implemented\n",
nameof_v<T>, nameof_v<S>, N, size(IlvBlkLayout{}));
}
__brkpt();
}
Expand Down
52 changes: 27 additions & 25 deletions csrc/quantization/machete/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
TORCH_CHECK_NOT_IMPLEMENTED(
false, "machete_mm(..) is not implemented for "
"a_type=", args.A.scalar_type(),
", b_type=", args.b_type.str(),
", b_type=", args.btype.str(),
", out_type=", out_type,
", with_group_scale_type=", maybe_g_scales_type
? toString(*maybe_g_scales_type) : "None",
Expand Down Expand Up @@ -525,40 +525,42 @@ def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]):

impl_configs = []

GPTQ_kernel_types = list((TypeConfig(
a=a,
b=b,
b_group_scale=a,
b_group_zeropoint=DataType.void,
b_channel_scale=DataType.void,
b_token_scale=DataType.void,
out=a,
accumulator=DataType.f32,
) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
for a in (DataType.f16, DataType.bf16)))
GPTQ_kernel_type_configs = list(
TypeConfig(
a=a,
b=b,
b_group_scale=a,
b_group_zeropoint=DataType.void,
b_channel_scale=DataType.void,
b_token_scale=DataType.void,
out=a,
accumulator=DataType.f32,
) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
for a in (DataType.f16, DataType.bf16))

impl_configs += [
ImplConfig(x[0], x[1], x[2])
for x in zip(GPTQ_kernel_types,
for x in zip(GPTQ_kernel_type_configs,
itertools.repeat(get_unique_schedules(default_heuristic)),
itertools.repeat(default_heuristic))
]

AWQ_kernel_types = list((TypeConfig(
a=a,
b=b,
b_group_scale=a,
b_group_zeropoint=a,
b_channel_scale=DataType.void,
b_token_scale=DataType.void,
out=a,
accumulator=DataType.f32,
) for b in (DataType.u4, DataType.u8)
for a in (DataType.f16, DataType.bf16)))
AWQ_kernel_type_configs = list(
TypeConfig(
a=a,
b=b,
b_group_scale=a,
b_group_zeropoint=a,
b_channel_scale=DataType.void,
b_token_scale=DataType.void,
out=a,
accumulator=DataType.f32,
) for b in (DataType.u4, DataType.u8)
for a in (DataType.f16, DataType.bf16))

impl_configs += [
ImplConfig(x[0], x[1], x[2])
for x in zip(AWQ_kernel_types,
for x in zip(AWQ_kernel_type_configs,
itertools.repeat(get_unique_schedules(default_heuristic)),
itertools.repeat(default_heuristic))
]
Expand Down

0 comments on commit 565770c

Please sign in to comment.