Skip to content

Commit

Permalink
Add graph as property of match, add comments, add utilities, extract …
Browse files Browse the repository at this point in the history
…ops to constants

Signed-off-by: luka <[email protected]>
  • Loading branch information
ProExpertProg committed Nov 12, 2024
1 parent 86324bf commit 2a90dd5
Showing 1 changed file with 73 additions and 51 deletions.
124 changes: 73 additions & 51 deletions vllm/compilation/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.compilation.config import CompilationConfig
from vllm.compilation.inductor_pass import InductorPass, is_func
from vllm.logger import init_logger

from .config import CompilationConfig
from .inductor_pass import InductorPass, is_func

logger = init_logger(__name__)


Expand Down Expand Up @@ -65,16 +66,19 @@ def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node:


class MultiOutputMatch(abc.ABC):
"""
This class provides utilities to process multi-output matches and
manually insert replacements.
This is necessary because the automatic replacement for multi-output
matches is broken: https://github.com/pytorch/pytorch/issues/137280
"""

def __init__(self, match: pm.Match):
self.match = match

@property
def nodes(self) -> List[torch.fx.Node]:
return self.match.nodes

@abstractmethod
def process(self, graph: torch.fx.Graph):
def process(self):
"""
Process a multi-output match and manually insert the replacement.
Expand Down Expand Up @@ -102,7 +106,21 @@ def process(self, graph: torch.fx.Graph):
"""
raise NotImplementedError

def inserting_after_match(self, graph: torch.fx.Graph):
@property
def nodes(self) -> List[torch.fx.Node]:
return self.match.nodes

@property
def graph(self) -> torch.fx.Graph:
return self.match.graph

def find_auto_fn(self, op) -> torch.fx.Node:
"""
Find the first auto_functionalized node with the given op in the match.
"""
return find_auto_fn(self.nodes, op)

def inserting_after_match(self):
"""
Insert nodes after the last node in the match.
This is done to avoid use-before-definition errors after inserting
Expand All @@ -111,51 +129,64 @@ def inserting_after_match(self, graph: torch.fx.Graph):

# match.nodes is not guaranteed to be sorted.
# Find the last node in the match.
for last_node_in_match in reversed(graph.nodes):
for last_node_in_match in reversed(self.graph.nodes):
if last_node_in_match in self.match.nodes:
break
else:
raise ValueError("No nodes in graph")

return graph.inserting_after(last_node_in_match)
return self.graph.inserting_after(last_node_in_match)

def insert_getitems(self, graph: torch.fx.Graph, tuple_node: torch.fx.Node,
indices: Tuple[int, ...]):
def insert_getitems(self, tuple_node: torch.fx.Node,
indices: Tuple[int, ...]) -> Tuple[torch.fx.Node, ...]:
"""
Insert operator.getitem nodes to extract elements from a tuple node.
:param graph: The graph to insert nodes into.
:param tuple_node: The tuple node to extract elements from.
:param indices: The indices of the elements to extract.
:return: Tuple of the new getitem nodes, corresponding to the indices.
"""
with graph.inserting_after(tuple_node):
return [
graph.call_function(operator.getitem, (tuple_node, idx))
for idx in indices
]
with self.graph.inserting_after(tuple_node):
return tuple(
self.graph.call_function(operator.getitem, (tuple_node, idx))
for idx in indices)

def insert_auto_fn(self, op, kwargs):
"""
Insert an auto_functionalized node with the given op and kwargs.
"""
return self.graph.call_function(auto_functionalized, (op, ),
kwargs=kwargs)


RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

QUANT_STATIC_FP8_OP = torch.ops._C.static_scaled_fp8_quant.default


class RMSNormQuantPattern:

def __init__(self, epsilon: float):
self.epsilon = epsilon


class RMSNormStaticFP8QuantPattern(RMSNormQuantPattern):

def register(self, pm_pass: PatternMatcherPass):
# Cannot use methods, as the self argument affects tracing
def pattern(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at1 = auto_functionalized(torch.ops._C.rms_norm.default,
at1 = auto_functionalized(RMS_OP,
result=result_rms,
input=input,
weight=weight,
epsilon=self.epsilon)
at2 = auto_functionalized(
torch.ops._C.static_scaled_fp8_quant.default,
result=result,
input=at1[1],
scale=scale)
at2 = auto_functionalized(QUANT_STATIC_FP8_OP,
result=result,
input=at1[1],
scale=scale)

# result
return at2[1]
Expand Down Expand Up @@ -186,27 +217,23 @@ def replacement(result: torch.Tensor, result_rms: torch.Tensor,
pm_pass)


class FusedAddRMSNormQuantPattern:

def __init__(self, epsilon: float):
self.epsilon = epsilon
class FusedAddRMSNormStaticFP8QuantPattern(RMSNormQuantPattern):

def register(self, pm_pass: PatternMatcherPass,
record_match: Callable[[MultiOutputMatch], bool]):

def pattern(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(torch.ops._C.fused_add_rms_norm.default,
at = auto_functionalized(RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon)
at1 = auto_functionalized(
torch.ops._C.static_scaled_fp8_quant.default,
result=result,
input=at[1],
scale=scale)
at1 = auto_functionalized(QUANT_STATIC_FP8_OP,
result=result,
input=at[1],
scale=scale)

# result, residual
return at1[1], at[2]
Expand Down Expand Up @@ -244,12 +271,10 @@ def replacement(result: torch.Tensor, input: torch.Tensor,

class Match(MultiOutputMatch):

def process(self, graph: torch.fx.Graph):
def process(self):
# Find the nodes in the match that we need to rebind
rms_node = find_auto_fn(self.match.nodes,
torch.ops._C.fused_add_rms_norm.default)
quant_node = find_auto_fn(
self.match.nodes, torch.ops._C.static_scaled_fp8_quant.default)
rms_node = self.find_auto_fn(RMS_ADD_OP)
quant_node = self.find_auto_fn(QUANT_STATIC_FP8_OP)

assert len(rms_node.users) == 2
assert len(quant_node.users) == 1
Expand All @@ -262,19 +287,17 @@ def process(self, graph: torch.fx.Graph):
# at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa
# result_node_new = at[1]
# residual_node_new = at[2]
with self.inserting_after_match(graph):
with self.inserting_after_match():
kwargs = self.match.kwargs.copy()

# Scalars cannot be inputs to the pattern
kwargs["epsilon"] = rms_node.kwargs["epsilon"]

fused_node = graph.call_function(
auto_functionalized,
(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
),
kwargs=kwargs)
fused_node = self.insert_auto_fn(
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
kwargs)

getitem_nodes = self.insert_getitems(graph, fused_node, (1, 2))
getitem_nodes = self.insert_getitems(fused_node, (1, 2))
result_node_new, residual_node_new = getitem_nodes

# Rebind the users of match getitem nodes to use the new nodes.
Expand Down Expand Up @@ -328,13 +351,13 @@ def __init__(self, config: CompilationConfig):
for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + static_scaled_fp8_quant into
# rms_norm_static_fp8_quant
RMSNormQuantPattern(epsilon).register(self.patterns)
RMSNormStaticFP8QuantPattern(epsilon).register(self.patterns)

# Fuse fused_add_rms_norm + static_scaled_fp8_quant into
# fused_add_rms_norm_static_fp8_quant
# Because pattern has 2 outputs, we need to manually process
# the match (see process_matches)
FusedAddRMSNormQuantPattern(epsilon).register(
FusedAddRMSNormStaticFP8QuantPattern(epsilon).register(
self.patterns, self.record_match)

# WARNING: This is a hack to clear the pattern matcher cache
Expand All @@ -352,11 +375,10 @@ def record_match(self, match: MultiOutputMatch) -> bool:
def process_matches(self, graph: torch.fx.Graph):
"""
Manually process multi-output matches and replace them with fused nodes.
This is necessary because the automatic replacement for multi-output
matches is broken: https://github.com/pytorch/pytorch/issues/137280
See MultiOutputMatch for more details.
"""
for match in self.matches:
match.process(graph)
match.process()

# Finally, remove matched nodes
graph.eliminate_dead_code()
Expand Down

0 comments on commit 2a90dd5

Please sign in to comment.