[bugfix] Fix static asymmetric quantization case #10334

Merged
3 commits merged on Nov 15, 2024
Changes from 2 commits
19 changes: 10 additions & 9 deletions tests/kernels/test_int8_quant.py
@@ -86,10 +86,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
     assert torch_out.min() >= int8_traits.min and torch_out.max(
     ) <= int8_traits.max
 
-    ops_out = torch.empty_like(x, dtype=torch.int8)
-    scales_out = torch.empty_like(scales, dtype=torch.float32)
-    azp_out = torch.empty_like(azps, dtype=torch.int32)
-    torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out, azp_out)
+    ops_out, scales_out, azp_out = scaled_int8_quant(x, symmetric=False)
 
     if (not torch.allclose(scales_out, scales)):
         print(torch.argmax(torch.abs(scales_out - scales)))
@@ -119,7 +116,8 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
 
     out1 = (x / scale_arg).round().clamp(int8_traits.min,
                                          int8_traits.max).to(torch.int8)
-    out2, _, _ = scaled_int8_quant(x, scale_arg)
+    out2, scale2, _ = scaled_int8_quant(x, scale_arg)
+    assert scale2 is scale_arg
 
     # big atol to account for rounding errors
     torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
@@ -145,11 +143,15 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
 
     out1 = ((x / scale).round() + azp).clamp(int8_traits.min,
                                              int8_traits.max).to(torch.int8)
-    out2 = torch.empty_like(x, dtype=torch.int8)
     scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
     azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda")
 
-    torch.ops._C.static_scaled_int8_quant(out2, x, scale_arg, azp_arg)
+    out2, scale2, azp2 = scaled_int8_quant(x,
+                                           scale_arg,
+                                           azp_arg,
+                                           symmetric=False)
+    assert scale2 is scale_arg
+    assert azp2 is azp_arg
 
     # big atol to account for rounding errors
     torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
@@ -184,6 +186,5 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
     val_i8 = int8_traits.max if is_max else int8_traits.min
     expected = torch.full((1, 5), val_i8, dtype=torch.int8, device="cuda")
 
-    out = torch.empty_like(expected)
-    torch.ops._C.static_scaled_int8_quant(out, x, scale, azp)
+    out, _, _ = scaled_int8_quant(x, scale, azp, symmetric=False)
     torch.testing.assert_close(expected, out, atol=0, rtol=0)
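
For reference, the asymmetric path these tests exercise reduces to a few lines of plain PyTorch. The sketch below mirrors the reference computation in test_static_scaled_int8_azp_quant above; the function name and constants are illustrative, not part of the PR.

import torch

INT8_MIN = torch.iinfo(torch.int8).min  # -128
INT8_MAX = torch.iinfo(torch.int8).max  # 127

def reference_static_azp_quant(x: torch.Tensor, scale: float, azp: int) -> torch.Tensor:
    # Asymmetric int8 quantization: divide by the scale, round, shift by the
    # zero point (azp), then saturate to the int8 range.
    return ((x / scale).round() + azp).clamp(INT8_MIN, INT8_MAX).to(torch.int8)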
8 changes: 7 additions & 1 deletion vllm/_custom_ops.py
@@ -510,10 +510,16 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
                           azp_adj: torch.Tensor,
                           azp: Optional[torch.Tensor] = None,
                           bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    """
+    :param azp_adj: In the per-tensor case, this should include the azp.
+        Always per-channel.
+    :param azp: Only set in the per-token case. Per-token if set.
+    """
     assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
     assert bias is None or bias.numel(
     ) == b.shape[1] and bias.dtype == out_dtype
+    assert azp is None or azp.numel() == a.shape[0]
 
     m = a.shape[0]
     n = b.shape[1]
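
To make the new docstring concrete, the snippet below is a rough emulation of the correction the azp epilogue applies, assuming azp_adj is the per-output-channel column sum of the int8 weight (as constructed in process_weights_after_loading further down). It is an illustrative sketch only, not the CUTLASS kernel.

import torch

def emulate_scaled_mm_azp(a_q, b_q, scale_a, scale_b, azp_adj, azp=None, bias=None):
    # a_q: [M, K] int8 activations, b_q: [K, N] int8 weights, azp_adj: [1, N].
    acc = a_q.to(torch.int32) @ b_q.to(torch.int32)
    if azp is None:
        # Per-tensor case: the scalar azp was already folded into azp_adj.
        acc = acc - azp_adj
    else:
        # Per-token case: one zero point per row, applied as an outer product.
        acc = acc - azp.view(-1, 1) * azp_adj
    out = scale_a * scale_b * acc.to(torch.float32)
    return out + bias if bias is not None else out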
@@ -735,7 +741,7 @@ def scaled_int8_quant(
             azp is
             None), "azp must only be provided for asymmetric quantization."
         torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
-        return output, scale, None
+        return output, scale, azp
 
     # dynamic-per-token quantization.
     input_scales = torch.empty((input.numel() // input.shape[-1], 1),
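
The one-line change above is the core fix of this PR: the static asymmetric branch used to return None for the zero point, so callers could not recover it. A hedged usage sketch follows (shapes follow the tests above; the import path assumes the function remains in vllm/_custom_ops.py, and the scale/azp values are made up for illustration).

import torch
from vllm._custom_ops import scaled_int8_quant

x = torch.randn(16, 256, dtype=torch.float16, device="cuda")
scale = torch.tensor([0.05], dtype=torch.float32, device="cuda")
azp = torch.tensor([7], dtype=torch.int32, device="cuda")

out, scale_out, azp_out = scaled_int8_quant(x, scale, azp, symmetric=False)
assert scale_out is scale and azp_out is azp  # azp is passed through, not None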
@@ -82,9 +82,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
         # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
         if not self.input_symmetric:
-            layer.azp_adj = layer.weight.sum(dim=0,
-                                             keepdim=True,
-                                             dtype=torch.int32)
+            azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32)
+            if self.is_static_input_scheme:
+                # cutlass_w8a8 requires azp to be folded into azp_adj
+                # in the per-tensor case
+                azp_adj = layer.input_zero_point * azp_adj
+
+            layer.azp_adj = azp_adj
         else:
             layer.azp_adj = None
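
The folding above works because the epilogue's asymmetric correction term is azp * weight.sum(dim=0); when azp is a per-tensor scalar, that product can be precomputed once at load time. A small sanity-check sketch of the identity (illustrative tensors, integer math on CPU):

import torch

w_q = torch.randint(-128, 128, (64, 32), dtype=torch.int32)  # [K, N] int8 weight, widened
a_q = torch.randint(-128, 128, (8, 64), dtype=torch.int32)   # [M, K] quantized activations
azp = torch.tensor(11, dtype=torch.int32)                    # per-tensor zero point

azp_adj = w_q.sum(dim=0, keepdim=True)  # [1, N], as in process_weights_after_loading
folded = azp * azp_adj                  # static case: fold azp into azp_adj at load time

# (a_q - azp) @ w_q  ==  a_q @ w_q - azp * colsum(w_q)
assert torch.equal((a_q - azp) @ w_q, a_q @ w_q - folded)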

@@ -138,7 +142,6 @@ def create_weights(self, layer: torch.nn.Module,
 
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
-
         return apply_int8_linear(input=x,
                                  weight=layer.weight,
                                  weight_scale=layer.weight_scale,
5 changes: 4 additions & 1 deletion vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -211,13 +211,16 @@ def apply_int8_linear(
                                                symmetric=symmetric)
 
     if x_zp is not None:
+        # Currently, static is always per-tensor and dynamic is per-token
+        static = input_zero_point is not None
+        azp = None if static else x_zp
         return ops.cutlass_scaled_mm_azp(x_q,
                                          weight,
                                          scale_a=x_scale,
                                          scale_b=weight_scale,
                                          out_dtype=input.dtype,
                                          azp_adj=azp_adj,
-                                         azp=x_zp,
+                                         azp=azp,
                                          bias=bias)
     return ops.cutlass_scaled_mm(x_q,
                                  weight,
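
Putting it together: static quantization is per-tensor, and its zero point is already folded into azp_adj at load time, so passing x_zp again would double-apply the correction; dynamic quantization is per-token, so the zero points must reach the kernel. The hypothetical helper below merely restates that dispatch with comments; it is not code from the PR.

from typing import Optional
import torch

def choose_kernel_azp(x_zp: Optional[torch.Tensor],
                      input_zero_point: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Static (per-tensor): azp already lives in azp_adj, so the kernel gets None.
    # Dynamic (per-token): x_zp holds one zero point per row and is passed through.
    static = input_zero_point is not None
    return None if static else x_zp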