Skip to content

Commit

Permalink
[bugfix] Fix static asymmetric quantization case (vllm-project#10334)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniël de Kok <[email protected]>
Signed-off-by: luka <[email protected]>
Co-authored-by: Daniël de Kok <[email protected]>
Signed-off-by: Maxime Fournioux <[email protected]>
  • Loading branch information
2 people authored and mfournioux committed Nov 20, 2024
1 parent 00e913f commit f534b0b
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 15 deletions.
19 changes: 10 additions & 9 deletions tests/kernels/test_int8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
assert torch_out.min() >= int8_traits.min and torch_out.max(
) <= int8_traits.max

ops_out = torch.empty_like(x, dtype=torch.int8)
scales_out = torch.empty_like(scales, dtype=torch.float32)
azp_out = torch.empty_like(azps, dtype=torch.int32)
torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out, azp_out)
ops_out, scales_out, azp_out = scaled_int8_quant(x, symmetric=False)

if (not torch.allclose(scales_out, scales)):
print(torch.argmax(torch.abs(scales_out - scales)))
Expand Down Expand Up @@ -119,7 +116,8 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,

out1 = (x / scale_arg).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
out2, _, _ = scaled_int8_quant(x, scale_arg)
out2, scale2, _ = scaled_int8_quant(x, scale_arg)
assert scale2 is scale_arg

# big atol to account for rounding errors
torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
Expand All @@ -145,11 +143,15 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,

out1 = ((x / scale).round() + azp).clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
out2 = torch.empty_like(x, dtype=torch.int8)
scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda")

torch.ops._C.static_scaled_int8_quant(out2, x, scale_arg, azp_arg)
out2, scale2, azp2 = scaled_int8_quant(x,
scale_arg,
azp_arg,
symmetric=False)
assert scale2 is scale_arg
assert azp2 is azp_arg

# big atol to account for rounding errors
torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
Expand Down Expand Up @@ -184,6 +186,5 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
val_i8 = int8_traits.max if is_max else int8_traits.min
expected = torch.full((1, 5), val_i8, dtype=torch.int8, device="cuda")

out = torch.empty_like(expected)
torch.ops._C.static_scaled_int8_quant(out, x, scale, azp)
out, _, _ = scaled_int8_quant(x, scale, azp, symmetric=False)
torch.testing.assert_close(expected, out, atol=0, rtol=0)
30 changes: 30 additions & 0 deletions tests/quantization/test_compressed_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
from compressed_tensors.quantization import QuantizationType

from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
Expand Down Expand Up @@ -74,6 +75,35 @@ def zp_valid(zp: Optional[torch.Tensor]):
assert output


@pytest.mark.parametrize(
    "model_path",
    [
        "neuralmagic/Llama-3.2-1B-quantized.w8a8"
        # TODO static & asymmetric
    ])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
def test_compressed_tensors_w8a8_logprobs(hf_runner, vllm_runner,
                                          example_prompts, model_path,
                                          max_tokens, num_logprobs):
    """Check that vLLM's greedy logprobs stay close to the HF reference.

    The same checkpoint is run in bfloat16 through ``hf_runner`` and then
    through ``vllm_runner``; the two sets of outputs are compared with
    ``check_logprobs_close``.
    """
    runner_kwargs = {"dtype": "bfloat16"}

    # Reference run: the HF model is generated first and its context is
    # closed before the vLLM model is created.
    with hf_runner(model_path, **runner_kwargs) as reference_model:
        reference_outputs = reference_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

    # Run under test: same prompts, token budget and logprob count.
    with vllm_runner(model_path, **runner_kwargs) as candidate_model:
        candidate_outputs = candidate_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(outputs_0_lst=reference_outputs,
                         outputs_1_lst=candidate_outputs,
                         name_0="hf",
                         name_1="vllm")


def test_compressed_tensors_no_enforce_eager(vllm_runner):
model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
with vllm_runner(model_path) as llm:
Expand Down
8 changes: 7 additions & 1 deletion vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,10 +510,16 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
azp_adj: torch.Tensor,
azp: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
:param azp_adj: In the per-tensor case, this should include the azp.
Always per-channel.
:param azp: Only set in the per-token case. Per-token if set.
"""
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
assert bias is None or bias.numel(
) == b.shape[1] and bias.dtype == out_dtype
assert azp is None or azp.numel() == a.shape[0]

m = a.shape[0]
n = b.shape[1]
Expand Down Expand Up @@ -735,7 +741,7 @@ def scaled_int8_quant(
azp is
None), "azp must only be provided for asymmetric quantization."
torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
return output, scale, None
return output, scale, azp

# dynamic-per-token quantization.
input_scales = torch.empty((input.numel() // input.shape[-1], 1),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
if not self.input_symmetric:
layer.azp_adj = layer.weight.sum(dim=0,
keepdim=True,
dtype=torch.int32)
azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32)
if self.is_static_input_scheme:
# cutlass_w8a8 requires azp to be folded into azp_adj
# in the per-tensor case
azp_adj = layer.input_zero_point * azp_adj

layer.azp_adj = azp_adj
else:
layer.azp_adj = None

Expand Down Expand Up @@ -138,7 +142,6 @@ def create_weights(self, layer: torch.nn.Module,

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:

return apply_int8_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
Expand Down
5 changes: 4 additions & 1 deletion vllm/model_executor/layers/quantization/utils/w8a8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,13 +211,16 @@ def apply_int8_linear(
symmetric=symmetric)

if x_zp is not None:
# Currently, static is always per-tensor and dynamic is per-token
static = input_zero_point is not None
azp = None if static else x_zp
return ops.cutlass_scaled_mm_azp(x_q,
weight,
scale_a=x_scale,
scale_b=weight_scale,
out_dtype=input.dtype,
azp_adj=azp_adj,
azp=x_zp,
azp=azp,
bias=bias)
return ops.cutlass_scaled_mm(x_q,
weight,
Expand Down

0 comments on commit f534b0b

Please sign in to comment.