Use new dynamic ops for fusion, tolerance has to be higher.
Signed-off-by: luka <[email protected]>
ProExpertProg committed Nov 21, 2024
1 parent 466fd53 commit 5f9a3be
Showing 6 changed files with 34 additions and 16 deletions.
2 changes: 1 addition & 1 deletion benchmarks/fused_kernels/layernorm_rms_benchmarks.py
@@ -52,7 +52,7 @@ def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
     torch_out, _ = rms_norm_layer.forward_cuda(x, residual)
 
     # Quant
-    torch_out, _ = ops.scaled_int8_quant(torch_out)
+    torch_out, _, _ = ops.scaled_int8_quant(torch_out)
 
 
 def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
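
For context: ops.scaled_int8_quant is assumed here to return a third value (the asymmetric-quantization zero point, None for symmetric quantization), which the benchmark discards. A hedged sketch of the updated call pattern, assuming a CUDA build of vllm; the variable names are illustrative only.

# Hedged sketch (not part of this commit): unpacking the assumed three-value
# return of ops.scaled_int8_quant; the third value is ignored, as in the
# benchmark above.
import torch
from vllm import _custom_ops as ops

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
quantized, scales, _ = ops.scaled_int8_quant(x)  # third return value discarded
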
2 changes: 1 addition & 1 deletion csrc/torch_bindings.cpp
@@ -130,7 +130,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   // Fused Layernorm + Quant kernels
   ops.def(
-      "rms_norm_dynamic_per_token_quant(Tensor! out, Tensor input, "
+      "rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, "
       "Tensor weight, Tensor! scales, float epsilon, "
      "Tensor? scale_ub, Tensor!? residual) -> ()");
   ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA,
5 changes: 3 additions & 2 deletions tests/compile/test_fusion.py
@@ -72,8 +72,9 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
     model2 = torch.compile(model, backend=backend)
     result2 = model2(x)
 
-    # Check that it gives the same answer
-    torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3)
+    # Check that it gives the same answer, higher tol for dynamic
+    ATOL, RTOL = (1e-3, 1e-3) if static else (1e-2, 1e-2)
+    torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
 
     # Check substitution worked
     pre_nodes = backend.graph_pre_pass.nodes
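
The tolerance bump matches the commit message: with dynamic (per-token) quantization the fused kernel computes the scales itself, so the fused and unfused paths can round differently. A minimal sketch of the comparison pattern, using the tolerance values from the diff above; the helper name is hypothetical.

# Minimal sketch of the relaxed-tolerance comparison used in the test above:
# static (per-tensor) quantization keeps the tight 1e-3 tolerances, while the
# dynamic per-token path is compared at 1e-2.
import torch

def assert_fused_matches_unfused(result: torch.Tensor, result2: torch.Tensor,
                                 static: bool) -> None:
    atol, rtol = (1e-3, 1e-3) if static else (1e-2, 1e-2)
    torch.testing.assert_close(result, result2, atol=atol, rtol=rtol)
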
29 changes: 23 additions & 6 deletions vllm/_custom_ops.py
@@ -22,6 +22,7 @@
 supports_moe_ops = False
 with contextlib.suppress(ImportError):
     import vllm._moe_C  # noqa: F401
+
     supports_moe_ops = True
 
 # neuron has torch version that doesn't even have impl_abstract
@@ -241,7 +242,6 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
                             paged_kv_indptr: torch.Tensor,
                             paged_kv_last_page_len: torch.Tensor,
                             block_table_bound: torch.Tensor) -> None:
-
     return torch.ops._C.advance_step_flashinfer(
         num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
         input_positions, seq_lens, slot_mapping, block_tables,
@@ -258,7 +258,6 @@ def rms_norm_dynamic_per_token_quant(
     scale_ub: Optional[torch.Tensor] = None,
     residual: Optional[torch.Tensor] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-
     output = torch.empty_like(input, dtype=quant_dtype)
     scales = torch.empty((input.numel() // input.shape[-1], 1),
                          device=input.device,
@@ -270,6 +269,24 @@
     return output, scales
 
 
+# TODO is this necessary?
+@register_fake("_C::rms_norm_dynamic_per_token_quant")
+def _rms_norm_dynamic_per_token_quant_fake(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    epsilon: float,
+    quant_dtype: torch.dtype,
+    scale_ub: Optional[torch.Tensor] = None,
+    residual: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    output = torch.empty_like(input, dtype=quant_dtype)
+    scales = torch.empty((input.numel() // input.shape[-1], 1),
+                         device=input.device,
+                         dtype=torch.float32)
+
+    return output, scales
+
+
 # quantization ops
 # awq
 def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
@@ -723,7 +740,7 @@ def scaled_fp8_quant(
     shape: Union[Tuple[int, int], torch.Size] = input.shape
     # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
     out_dtype: torch.dtype = torch.float8_e4m3fnuz \
-            if current_platform.is_rocm() else torch.float8_e4m3fn
+        if current_platform.is_rocm() else torch.float8_e4m3fn
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
     output = torch.empty(shape, device=input.device, dtype=out_dtype)
@@ -1006,9 +1023,9 @@ def register_graph_buffers(fa: int, handles: List[List[int]],
     # the case when users use `import __annotations__` to turn type
     # hints into strings.
     if isinstance(v, fn_type) \
-        and v.__code__.co_filename == __file__ \
-        and any(arg is torch.Tensor or arg == "torch.Tensor"
-            for arg in v.__annotations__.values()):
+            and v.__code__.co_filename == __file__ \
+            and any(arg is torch.Tensor or arg == "torch.Tensor"
+                    for arg in v.__annotations__.values()):
         names_and_values_to_update[k] = hint_on_error(v)
 
 names_and_values.update(names_and_values_to_update)
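
The new _rms_norm_dynamic_per_token_quant_fake registration gives the custom op a shape- and dtype-only ("fake"/meta) implementation so torch.compile can trace it without launching the CUDA kernel. A standalone sketch of the same pattern, using a hypothetical op name and assuming PyTorch 2.4+ for torch.library.register_fake:

# Standalone sketch (hypothetical "mylib::scale_to_fp8" op, not vllm's): a fake
# implementation lets FakeTensor tracing infer output shapes and dtypes without
# running the real kernel.
import torch
from torch.library import Library, register_fake  # register_fake: PyTorch 2.4+

_lib = Library("mylib", "DEF")  # hypothetical namespace for this example
_lib.define("scale_to_fp8(Tensor x) -> (Tensor, Tensor)")

@register_fake("mylib::scale_to_fp8")
def _scale_to_fp8_fake(x: torch.Tensor):
    # Only metadata matters: an fp8 tensor shaped like x plus one scale per row.
    out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    scales = torch.empty((x.numel() // x.shape[-1], 1),
                         device=x.device, dtype=torch.float32)
    return out, scales
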
9 changes: 4 additions & 5 deletions vllm/compilation/fusion.py
@@ -22,9 +22,8 @@
 def rms_norm_dynamic_fp8_quant(result: torch.Tensor, input: torch.Tensor,
                                weight: torch.Tensor, scale: torch.Tensor,
                                epsilon: float) -> None:
-    result_rms = torch.empty_like(input)
-    torch.ops._C.rms_norm(result_rms, input, weight, epsilon)
-    torch.ops._C.dynamic_scaled_fp8_quant(result, result_rms, scale)
+    # Last two are scale_ub, residual
+    torch.ops._C.rms_norm_dynamic_per_token_quant(result, input, weight, scale, epsilon, None, None)
 
 
 @torch.library.register_fake("_C::rms_norm_dynamic_fp8_quant")
@@ -41,8 +40,8 @@ def fused_add_rms_norm_dynamic_fp8_quant(result: torch.Tensor,
                                          weight: torch.Tensor,
                                          scale: torch.Tensor,
                                          epsilon: float) -> None:
-    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
-    torch.ops._C.dynamic_scaled_fp8_quant(result, input, scale)
+    # Last two are scale_ub, residual
+    torch.ops._C.rms_norm_dynamic_per_token_quant(result, input, weight, scale, epsilon, None, residual)
 
 
 @torch.library.register_fake("_C::rms_norm_dynamic_fp8_quant")
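
Both replacement patterns above now dispatch to the single fused kernel instead of separate rms_norm / fused_add_rms_norm and dynamic_scaled_fp8_quant calls. A hedged usage sketch of the Python wrapper from vllm/_custom_ops.py, with parameter names taken from the fake registration shown earlier (assumed to match the real wrapper) and requiring a CUDA build of vllm:

# Hedged usage sketch: ops.rms_norm_dynamic_per_token_quant allocates the fp8
# output and per-token scales, then calls the fused CUDA kernel bound in
# csrc/torch_bindings.cpp.
import torch
from vllm import _custom_ops as ops

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
weight = torch.ones(4096, dtype=torch.float16, device="cuda")
out, scales = ops.rms_norm_dynamic_per_token_quant(
    x, weight, epsilon=1e-6, quant_dtype=torch.float8_e4m3fn)
# out: (16, 4096) fp8 tensor; scales: (16, 1) float32 per-token scales
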
3 changes: 2 additions & 1 deletion vllm/plugins/__init__.py
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING, Optional
 
 import vllm.envs as envs
+from vllm.utils import print_warning_once
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -93,7 +94,7 @@ def get_current_vllm_config() -> "VllmConfig":
         # in ci, usually when we test custom ops/modules directly,
         # we don't set the vllm config. In that case, we set a default
         # config.
-        logger.warning("Current VLLM config is not set.")
+        print_warning_once("Current VLLM config is not set.")
         from vllm.config import VllmConfig
         return VllmConfig()
     return _current_vllm_config
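
The switch from logger.warning to print_warning_once is assumed to deduplicate the message when a default config is constructed repeatedly (e.g. across many test cases). A tiny sketch of the assumed behavior:

# Sketch of the assumed deduplication behavior of vllm.utils.print_warning_once:
# identical messages are emitted only once per process.
from vllm.utils import print_warning_once

for _ in range(3):
    print_warning_once("Current VLLM config is not set.")  # logged a single time
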
