In progress dynamic fusion debugging
ProExpertProg committed Nov 21, 2024
1 parent 5f9a3be commit 651ebdc
Showing 6 changed files with 47 additions and 78 deletions.
csrc/torch_bindings.cpp (2 changes: 1 addition & 1 deletion)
@@ -131,7 +131,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Fused Layernorm + Quant kernels
ops.def(
"rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, "
"Tensor weight, Tensor! scales, float epsilon, "
"Tensor weight, Tensor! scale, float epsilon, "
"Tensor? scale_ub, Tensor!? residual) -> ()");
ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA,
&rms_norm_dynamic_per_token_quant);
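Note: renaming "scales" to "scale" in the schema matches the argument name used by the fusion pass and the functionalization fix below. A minimal sketch of invoking the op under the updated schema (shapes, dtypes, and the epsilon value are illustrative assumptions, not part of the commit):

import torch

num_tokens, hidden_size = 16, 1024
input = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(hidden_size, dtype=torch.bfloat16, device="cuda")
# Outputs are pre-allocated; the op mutates them in place (Tensor! in the schema).
result = torch.empty(num_tokens, hidden_size, dtype=torch.float8_e4m3fn, device="cuda")
scale = torch.empty(num_tokens, 1, dtype=torch.float32, device="cuda")

# scale_ub (Tensor?) and residual (Tensor!?) are optional; None disables them.
torch.ops._C.rms_norm_dynamic_per_token_quant(result, input, weight, scale,
                                              1e-6, None, None)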
tests/compile/test_fusion.py (20 changes: 13 additions & 7 deletions)
@@ -52,11 +52,14 @@ def forward(self, x):
reason="Only test on CUDA")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)
torch.set_default_dtype(dtype)
torch.manual_seed(1)

# Reshape pass is needed for the fusion pass to work
config = CompilationConfig.PassConfig(enable_fusion=True,
enable_reshape=True)
config = CompilationConfig.PassConfig(
enable_fusion=True,
enable_reshape=True,
dump_graph_stages=["before_fusion", "after_fusion"])
reshape_pass = RedundantReshapesPass(config)
fusion_pass = FusionPass.instance(config)

@@ -73,8 +76,11 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
result2 = model2(x)

# Check that it gives the same answer, higher tol for dynamic
ATOL, RTOL = (1e-3, 1e-3) if static else (1e-2, 1e-2)
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
ATOL, RTOL = (1e-3, 1e-3) if static else (2e-2, 2e-2)
torch.testing.assert_close(result.to(dtype=torch.float32),
result2.to(dtype=torch.float32),
atol=ATOL,
rtol=RTOL)

# Check substitution worked
pre_nodes = backend.graph_pre_pass.nodes
Expand All @@ -85,8 +91,8 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default # noqa: E501
fp8_quant = torch.ops._C.static_scaled_fp8_quant.default
else:
rms_quant = torch.ops._C.rms_norm_dynamic_fp8_quant.default
add_rms_quant = torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default # noqa: E501
rms_quant = torch.ops._C.rms_norm_dynamic_per_token_quant.default
add_rms_quant = torch.ops._C.rms_norm_dynamic_per_token_quant.default # noqa: E501
fp8_quant = torch.ops._C.dynamic_scaled_fp8_quant.default

# In pre-nodes, fp8 quant should be present and fused kernels should not
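The truncated assertions at the end of this test inspect the FX graphs captured before and after the pass. A rough sketch of that kind of check, reusing pre_nodes and the op handles defined above (illustrative only, not the test's actual helper code):

from torch._higher_order_ops.auto_functionalize import auto_functionalized

def has_op(nodes, op):
    # Mutating custom ops appear wrapped in auto_functionalized nodes,
    # with the target op as the first positional argument.
    return any(n.target is auto_functionalized and n.args[0] is op
               for n in nodes)

# In pre-nodes, fp8 quant should be present and the fused kernel should not.
assert has_op(pre_nodes, fp8_quant)
assert not has_op(pre_nodes, rms_quant)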
tests/kernels/test_fused_quant_layernorm.py (4 changes: 2 additions & 2 deletions)
@@ -8,8 +8,8 @@

DTYPES = [torch.bfloat16, torch.float]
QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
NUM_TOKENS = [1, 7, 83, 2048, 4096] # Arbitrary values for testing
HIDDEN_SIZES = [1, 2, 3, 4, 16, 64, 67, 768, 2048, 5120, 5137, 8192,
8193] # Arbitrary values for testing
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
ADD_RESIDUAL = [False, True]
vllm/_custom_ops.py (21 changes: 8 additions & 13 deletions)
@@ -272,19 +272,14 @@ def rms_norm_dynamic_per_token_quant(
# TODO is this necessary?
@register_fake("_C::rms_norm_dynamic_per_token_quant")
def _rms_norm_dynamic_per_token_quant_fake(
input: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
quant_dtype: torch.dtype,
scale_ub: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
output = torch.empty_like(input, dtype=quant_dtype)
scales = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)

return output, scales
output: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scales: torch.Tensor,
epsilon: float,
scale_ub: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> None:
return None


# quantization ops
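The fake (meta) implementation can return None here because output allocation now lives in the Python wrapper rather than in the fake itself. A rough sketch of what that wrapper looks like, inferred from the schema and from the allocation logic the old fake used (the real wrapper sits just above this hunk and may differ):

from typing import Optional, Tuple
import torch

def rms_norm_dynamic_per_token_quant(
        input: torch.Tensor,
        weight: torch.Tensor,
        epsilon: float,
        quant_dtype: torch.dtype,
        scale_ub: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Allocate the quantized output and one float32 scale per token.
    output = torch.empty_like(input, dtype=quant_dtype)
    scales = torch.empty((input.numel() // input.shape[-1], 1),
                         device=input.device, dtype=torch.float32)
    # The CUDA op mutates output/scales (and residual, if provided) in place.
    torch.ops._C.rms_norm_dynamic_per_token_quant(output, input, weight,
                                                  scales, epsilon, scale_ub,
                                                  residual)
    return output, scales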
vllm/compilation/fix_functionalization.py (8 changes: 2 additions & 6 deletions)
@@ -60,19 +60,15 @@ def __call__(self, graph: torch.fx.Graph):
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
mutated_args = {1: 'result', 2: 'residual'}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default: # noqa: E501
mutated_args = {1: 'result', 2: 'residual', 3: 'scale'}
elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
self.defunctionalize(graph, node, mutated_args)

elif at_target in [
torch.ops._C.rms_norm.default,
torch.ops._C.rms_norm_static_fp8_quant.default
]:
mutated_args = {1: 'result'}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.rms_norm_dynamic_fp8_quant.default:
mutated_args = {1: 'result', 2: 'scale'}
self.defunctionalize(graph, node, mutated_args)

elif at_target == torch.ops._C.silu_and_mul.default:
mutated_args = {1: 'out'}
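For context, mutated_args maps each index of the auto_functionalized node's output tuple to the kwarg whose tensor the op mutates, so defunctionalization can rewire users of at[i] back to the original arguments. A schematic before/after for the new op (a sketch of the graph shapes involved, not the pass's code):

# Functionalized form produced during compilation:
#   at = auto_functionalized(
#       torch.ops._C.rms_norm_dynamic_per_token_quant.default,
#       result=result, input=input, weight=weight, scale=scale,
#       epsilon=eps, scale_ub=None, residual=residual)
#   out, scale_out, residual_out = at[1], at[2], at[3]
#
# After defunctionalization, the op is called directly and users of the
# getitem nodes read the mutated inputs instead:
#   torch.ops._C.rms_norm_dynamic_per_token_quant.default(
#       result, input, weight, scale, eps, None, residual)
#   out, scale_out, residual_out = result, scale, residual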
vllm/compilation/fusion.py (70 changes: 21 additions & 49 deletions)
@@ -16,40 +16,6 @@
logger = init_logger(__name__)


# TODO temp
@torch.library.custom_op("_C::rms_norm_dynamic_fp8_quant",
mutates_args=("result", "scale"))
def rms_norm_dynamic_fp8_quant(result: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, scale: torch.Tensor,
epsilon: float) -> None:
# Last two are scale_ub, residual
torch.ops._C.rms_norm_dynamic_per_token_quant(result, input, weight, scale, epsilon, None, None)


@torch.library.register_fake("_C::rms_norm_dynamic_fp8_quant")
def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor, epsilon: float):
return None


@torch.library.custom_op("_C::fused_add_rms_norm_dynamic_fp8_quant",
mutates_args=("result", "residual", "scale"))
def fused_add_rms_norm_dynamic_fp8_quant(result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float) -> None:
# Last two are scale_ub, residual
torch.ops._C.rms_norm_dynamic_per_token_quant(result, input, weight, scale, epsilon, None, residual)


@torch.library.register_fake("_C::rms_norm_dynamic_fp8_quant")
def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor, epsilon: float):
return None


def empty_bf16(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")

@@ -372,12 +338,14 @@ def replacement(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(
torch.ops._C.rms_norm_static_fp8_quant.default,
torch.ops._C.rms_norm_dynamic_per_token_quant.default,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon)
epsilon=self.epsilon,
scale_ub=None,
residual=None)

# result, scale
return at[1], at[2]
@@ -413,18 +381,20 @@ def process(self):
# The auto_fn node returns a tuple of (None, result, scale).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.rms_norm_static_fp8_quant.default, ...) # noqa
# at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa
# result_node_new = at[1]
# scale_node_new = at[2]
with self.inserting_after_match():
kwargs = self.match.kwargs.copy()

# Scalars cannot be inputs to the pattern
kwargs["epsilon"] = rms_node.kwargs["epsilon"]
kwargs["scale_ub"] = None # not used but required
kwargs["residual"] = None # not used but required
del kwargs["result_rms"] # not used in the fused op

fused_node = self.insert_auto_fn(
torch.ops._C.rms_norm_dynamic_fp8_quant.default,
torch.ops._C.rms_norm_dynamic_per_token_quant.default,
kwargs=kwargs)

getitem_nodes = self.insert_getitems(fused_node, (1, 2))
@@ -466,16 +436,17 @@ def replacement(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(
torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default,
torch.ops._C.rms_norm_dynamic_per_token_quant.default,
result=result,
input=input,
residual=residual,
weight=weight,
scale=scale,
epsilon=self.epsilon)
epsilon=self.epsilon,
scale_ub=None,
residual=residual)

# result, residual, scale
return at[1], at[2], at[3] # TODO confirm signature
return at[1], at[3], at[2]

inputs = [
empty_fp8(5, 4), # result
@@ -508,22 +479,23 @@ def process(self):
# The auto_fn node returns a tuple (None, result, scale, residual).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default, ...) # noqa
# at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa
# result_node_new = at[1]
# residual_node_new = at[2]
# scale_node_new = at[3]
# scale_node_new = at[2]
# residual_node_new = at[3]
with self.inserting_after_match():
kwargs = self.match.kwargs.copy()

# Scalars cannot be inputs to the pattern
kwargs["epsilon"] = rms_node.kwargs["epsilon"]
kwargs["scale_ub"] = None # not used but required

fused_node = self.insert_auto_fn(
torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default,
torch.ops._C.rms_norm_dynamic_per_token_quant.default,
kwargs=kwargs)

getitem_ns = self.insert_getitems(fused_node, (1, 2, 3))
result_node_new, residual_node_new, scale_node_new = getitem_ns
result_node_new, scale_node_new, residual_node_new = getitem_ns

# Rebind the users of match getitem nodes to use the new nodes.
# The old nodes will be removed by DCE at the end of the pass.
@@ -588,12 +560,12 @@ def __init__(self, config: CompilationConfig.PassConfig):
self.patterns, self.record_match)

# Fuse rms_norm + dynamic_scaled_fp8_quant into
# rms_norm_dynamic_fp8_quant
# rms_norm_dynamic_per_token_quant
RMSNormDynamicFP8QuantPattern(epsilon).register(
self.patterns, self.record_match)

# Fuse fused_add_rms_norm + dynamic_scaled_fp8_quant into
# fused_add_rms_norm_dynamic_fp8_quant
# rms_norm_dynamic_per_token_quant
FusedAddRMSNormDynamicFP8QuantPattern(epsilon).register(
self.patterns, self.record_match)

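Net effect of the two dynamic-quant registrations above: an rms_norm (or fused_add_rms_norm) call followed by dynamic_scaled_fp8_quant is collapsed into a single rms_norm_dynamic_per_token_quant call. A schematic of the rewrite for the residual-free case (the argument orders of rms_norm and dynamic_scaled_fp8_quant are assumptions here; only the fused op's schema appears in this diff):

# Unfused graph matched by the pattern:
#   rms_norm(result_rms, input, weight, epsilon)
#   dynamic_scaled_fp8_quant(result, result_rms, scale)
#
# Fused replacement emitted by the pass:
#   rms_norm_dynamic_per_token_quant(result, input, weight, scale,
#                                    epsilon, scale_ub=None, residual=None)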
