diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py
index 12c578db0893c..761eb95c423fc 100644
--- a/tests/kernels/test_int8_quant.py
+++ b/tests/kernels/test_int8_quant.py
@@ -86,10 +86,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
     assert torch_out.min() >= int8_traits.min and torch_out.max(
     ) <= int8_traits.max
 
-    ops_out = torch.empty_like(x, dtype=torch.int8)
-    scales_out = torch.empty_like(scales, dtype=torch.float32)
-    azp_out = torch.empty_like(azps, dtype=torch.int32)
-    torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out, azp_out)
+    ops_out, scales_out, azp_out = scaled_int8_quant(x, symmetric=False)
 
     if (not torch.allclose(scales_out, scales)):
         print(torch.argmax(torch.abs(scales_out - scales)))
@@ -119,7 +116,8 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
 
     out1 = (x / scale_arg).round().clamp(int8_traits.min,
                                          int8_traits.max).to(torch.int8)
-    out2, _, _ = scaled_int8_quant(x, scale_arg)
+    out2, scale2, _ = scaled_int8_quant(x, scale_arg)
+    assert scale2 is scale_arg
 
     # big atol to account for rounding errors
     torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
@@ -145,11 +143,15 @@ def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
 
     out1 = ((x / scale).round() + azp).clamp(int8_traits.min,
                                              int8_traits.max).to(torch.int8)
-    out2 = torch.empty_like(x, dtype=torch.int8)
     scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
     azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda")
 
-    torch.ops._C.static_scaled_int8_quant(out2, x, scale_arg, azp_arg)
+    out2, scale2, azp2 = scaled_int8_quant(x,
+                                           scale_arg,
+                                           azp_arg,
+                                           symmetric=False)
+    assert scale2 is scale_arg
+    assert azp2 is azp_arg
 
     # big atol to account for rounding errors
     torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
@@ -184,6 +186,5 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
     val_i8 = int8_traits.max if is_max else int8_traits.min
     expected = torch.full((1, 5), val_i8, dtype=torch.int8, device="cuda")
 
-    out = torch.empty_like(expected)
-    torch.ops._C.static_scaled_int8_quant(out, x, scale, azp)
+    out, _, _ = scaled_int8_quant(x, scale, azp, symmetric=False)
     torch.testing.assert_close(expected, out, atol=0, rtol=0)
diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index 03097569b2b3b..26add5bf6d90d 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -8,6 +8,7 @@
 import torch
 from compressed_tensors.quantization import QuantizationType
 
+from tests.models.utils import check_logprobs_close
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
     CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
@@ -74,6 +75,35 @@ def zp_valid(zp: Optional[torch.Tensor]):
     assert output
 
 
+@pytest.mark.parametrize(
+    "model_path",
+    [
+        "neuralmagic/Llama-3.2-1B-quantized.w8a8"
+        # TODO static & asymmetric
+    ])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("num_logprobs", [10])
+def test_compressed_tensors_w8a8_logprobs(hf_runner, vllm_runner,
+                                          example_prompts, model_path,
+                                          max_tokens, num_logprobs):
+    dtype = "bfloat16"
+
+    with hf_runner(model_path, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy_logprobs_limit(
+            example_prompts, max_tokens, num_logprobs)
+
+    with vllm_runner(model_path, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
+
+
 def test_compressed_tensors_no_enforce_eager(vllm_runner):
     model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
     with vllm_runner(model_path) as llm:
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 8f331a27a20de..b276b8fc25473 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -510,10 +510,16 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
                           azp_adj: torch.Tensor,
                           azp: Optional[torch.Tensor] = None,
                           bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    """
+    :param azp_adj: Zero-point adjustment term, always per-channel.
+        In the per-tensor case, it must already include the azp.
+    :param azp: Per-token zero points, only set in the per-token case.
+    """
     assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
     assert bias is None or bias.numel(
     ) == b.shape[1] and bias.dtype == out_dtype
+    assert azp is None or azp.numel() == a.shape[0]
 
     m = a.shape[0]
     n = b.shape[1]
@@ -735,7 +741,7 @@ def scaled_int8_quant(
         assert symmetric == (
             azp is None), "azp must only be provided for asymmetric quantization."
         torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
-        return output, scale, None
+        return output, scale, azp
 
     # dynamic-per-token quantization.
     input_scales = torch.empty((input.numel() // input.shape[-1], 1),
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
index 15d9cdbcbb86b..6cbc58d61e970 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
@@ -82,9 +82,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
         # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
         if not self.input_symmetric:
-            layer.azp_adj = layer.weight.sum(dim=0,
-                                             keepdim=True,
-                                             dtype=torch.int32)
+            azp_adj = layer.weight.sum(dim=0, keepdim=True, dtype=torch.int32)
+            if self.is_static_input_scheme:
+                # cutlass_w8a8 requires azp to be folded into azp_adj
+                # in the per-tensor case
+                azp_adj = layer.input_zero_point * azp_adj
+
+            layer.azp_adj = azp_adj
         else:
             layer.azp_adj = None
 
@@ -138,7 +142,6 @@ def create_weights(self, layer: torch.nn.Module,
 
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
-
         return apply_int8_linear(input=x,
                                  weight=layer.weight,
                                  weight_scale=layer.weight_scale,
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index ec73533126ab6..4037bcb963b25 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -211,13 +211,16 @@ def apply_int8_linear(
                                                symmetric=symmetric)
 
     if x_zp is not None:
+        # Currently, static is always per-tensor and dynamic is per-token
+        static = input_zero_point is not None
+        azp = None if static else x_zp
        return ops.cutlass_scaled_mm_azp(x_q,
                                          weight,
                                          scale_a=x_scale,
                                          scale_b=weight_scale,
                                          out_dtype=input.dtype,
                                          azp_adj=azp_adj,
-                                         azp=x_zp,
+                                         azp=azp,
                                          bias=bias)
     return ops.cutlass_scaled_mm(x_q,
                                  weight,
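
Note on the azp_adj folding above: per csrc/quantization/cutlass_w8a8/Epilogues.md, the CUTLASS epilogue computes x_q @ w_q and corrects for the activation zero point afterwards. The identity this relies on can be checked numerically. The sketch below is illustrative only and not part of the patch; names and shapes are invented, and int64 is used so the CPU check is exact (the kernel itself takes int8 inputs with int32 accumulation).

import torch

# Quantized activations (4 tokens) and weights (8 in, 16 out channels).
x_q = torch.randint(-128, 128, (4, 8))
w_q = torch.randint(-128, 128, (8, 16))

# azp_adj is the per-channel column sum of the weights, as computed in
# process_weights_after_loading.
azp_adj = w_q.sum(dim=0, keepdim=True)

# Per-tensor (static) case: a scalar azp folds into azp_adj, which is
# why the scheme multiplies input_zero_point into azp_adj up front.
azp = torch.tensor(11)
assert torch.equal((x_q - azp) @ w_q, x_q @ w_q - azp * azp_adj)

# Per-token (dynamic) case: azp is a column vector and cannot be folded
# ahead of time; it is passed to the kernel separately and applied per row.
azp_tok = torch.randint(-16, 16, (4, 1))
assert torch.equal((x_q - azp_tok) @ w_q, x_q @ w_q - azp_tok * azp_adj)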
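
Similarly, a hedged sketch of the scaled_int8_quant contract that the kernel tests above exercise; it assumes a CUDA build of vLLM with the custom ops compiled and is not part of the patch.

import torch

from vllm import _custom_ops as ops

x = torch.randn(16, 256, device="cuda", dtype=torch.float16)

# Dynamic (per-token) asymmetric: one scale and zero point per token.
x_q, scales, azp = ops.scaled_int8_quant(x, symmetric=False)
assert scales.shape == (16, 1) and azp.shape == (16, 1)

# Static (per-tensor) asymmetric: with the `return output, scale, azp`
# fix above, the provided scale and azp now come back unchanged.
scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")
zp = torch.tensor([5], dtype=torch.int32, device="cuda")
x_q2, scale2, zp2 = ops.scaled_int8_quant(x, scale, zp, symmetric=False)
assert scale2 is scale and zp2 is zp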