Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Lucas Wilkinson <[email protected]>
  • Loading branch information
LucasWilkinson committed Nov 14, 2024
1 parent ef43d89 commit 6af8654
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 13 deletions.
7 changes: 4 additions & 3 deletions benchmarks/kernels/benchmark_machete.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,10 +549,11 @@ def to_torch_dtype(dt):
"int": torch.int,
"float": torch.float,
}[dt]

class ToTorchDtype(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, to_torch_dtype(values))

def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, to_torch_dtype(values))

parser = FlexibleArgumentParser(
description="""
Expand Down
16 changes: 8 additions & 8 deletions csrc/quantization/machete/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@
&& {%if t.b_channel_scale != void -%}
maybe_ch_scales_type == {{TorchTypeTag[t.b_channel_scale]}}
{%- else %}!maybe_ch_scales_type{%endif%}
&& {%if t.b_token_scale != void -%}
maybe_tok_scales_type == {{TorchTypeTag[t.b_token_scale]}}
&& {%if t.a_token_scale != void -%}
maybe_tok_scales_type == {{TorchTypeTag[t.a_token_scale]}}
{%- else %}!maybe_tok_scales_type{%endif%}
) {
return mm_dispatch_{{type_sig}}(args);
Expand Down Expand Up @@ -188,7 +188,7 @@
{{DataTypeTag[t.b_group_scale]}}, // GroupScaleT
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
{{DataTypeTag[t.b_token_scale]}}, // TokenScaleT
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
Sch>;
Expand Down Expand Up @@ -259,7 +259,7 @@ class TypeConfig:
b_group_scale: DataType
b_group_zeropoint: DataType
b_channel_scale: DataType
b_token_scale: DataType
a_token_scale: DataType
out: DataType
accumulator: DataType

Expand Down Expand Up @@ -532,7 +532,7 @@ def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]):
b_group_scale=a,
b_group_zeropoint=DataType.void,
b_channel_scale=DataType.void,
b_token_scale=DataType.void,
a_token_scale=DataType.void,
out=a,
accumulator=DataType.f32,
) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
Expand All @@ -552,7 +552,7 @@ def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]):
b_group_scale=a,
b_group_zeropoint=a,
b_channel_scale=DataType.void,
b_token_scale=DataType.void,
a_token_scale=DataType.void,
out=a,
accumulator=DataType.f32,
) for b in (DataType.u4, DataType.u8)
Expand Down Expand Up @@ -615,7 +615,7 @@ def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]):
b_group_scale=b_group_scale,
b_group_zeropoint=DataType.void,
b_channel_scale=DataType.f32,
b_token_scale=DataType.f32,
a_token_scale=DataType.f32,
out=DataType.f16,
accumulator=DataType.s32,
) for b_group_scale in (DataType.f16, DataType.void)),
Expand All @@ -625,7 +625,7 @@ def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]):
b_group_scale=b_group_scale,
b_group_zeropoint=DataType.void,
b_channel_scale=DataType.f32,
b_token_scale=DataType.f32,
a_token_scale=DataType.f32,
out=DataType.f16,
accumulator=DataType.f32,
) for b_group_scale in (DataType.f16, DataType.void)),
Expand Down
File renamed without changes.
2 changes: 0 additions & 2 deletions vllm/model_executor/layers/quantization/utils/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,6 @@ def unpack_cols(

orig_device = packed_q_w.device

packed_q_w = packed_q_w.t()

packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

Expand Down

0 comments on commit 6af8654

Please sign in to comment.