Commit

TEMP: dynamic quant
Signed-off-by: luka <[email protected]>
ProExpertProg committed Nov 12, 2024
1 parent 2a90dd5 commit 7d1adbf
Showing 3 changed files with 265 additions and 10 deletions.
29 changes: 20 additions & 9 deletions tests/compile/test_fusion.py
@@ -16,10 +16,15 @@

class TestModel(torch.nn.Module):

def __init__(self, hidden_size: int, eps: float, *args, **kwargs):
def __init__(self, hidden_size: int, eps: float, static: bool, *args,
**kwargs):
super().__init__(*args, **kwargs)
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
else:
self.scale = [None for _ in range(2)]
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(2)
@@ -29,11 +34,11 @@ def forward(self, x):
resid = torch.relu(x)
y = self.norm[0](x)

x2 = apply_fp8_linear(y, self.w[0], self.scale[0], self.scale[1])
x2 = apply_fp8_linear(y, self.w[0], self.wscale[0], self.scale[0])
# make sure resid is used for replacement to work
y2, resid = self.norm[1](x2, resid)

x3 = apply_fp8_linear(y2, self.w[1], self.scale[2], self.scale[3])
x3 = apply_fp8_linear(y2, self.w[1], self.wscale[1], self.scale[1])
y3, resid = self.norm[2](x3, resid) # use resid here
return y3

@@ -48,15 +53,16 @@ def forward(self, x):
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)

# Reshape pass is needed for the fusion pass to work
backend = TestBackend(reshape_pass, fusion_pass)
model = TestModel(hidden_size, eps)
model = TestModel(hidden_size, eps, static)

# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
@@ -74,9 +80,14 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
pre_nodes = backend.graph_pre_pass.nodes
post_nodes = backend.graph_post_pass.nodes

rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default
add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
fp8_quant = torch.ops._C.static_scaled_fp8_quant.default
if static:
rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default
add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default # noqa: E501
fp8_quant = torch.ops._C.static_scaled_fp8_quant.default
else:
rms_quant = torch.ops._C.rms_norm_dynamic_fp8_quant.default
add_rms_quant = torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default # noqa: E501
fp8_quant = torch.ops._C.dynamic_scaled_fp8_quant.default

# In pre-nodes, fp8 quant should be present and fused kernels should not
assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
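For context on the new static parametrization: with a static scale, quantization uses a scale calibrated ahead of time, while the dynamic path derives the scale from the current activation, which is why the dynamic kernels also produce scale as an output. A rough per-tensor FP8 (e4m3) sketch of the two modes, assuming the usual amax-based scale rather than reproducing the CUDA kernels:

import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max  # 448.0 for e4m3fn

def quant_static(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # static: the scale was computed offline and is an input
    return (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8)

def quant_dynamic(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # dynamic: the scale is derived from the tensor itself, so it is also an output
    scale = x.abs().amax().float() / FP8_MAX
    return (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8), scale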
8 changes: 7 additions & 1 deletion vllm/compilation/functionalization.py
@@ -47,20 +47,26 @@ def __call__(self, graph: torch.fx.Graph):
self.insert_defunctionalized(graph, node)
self._remove(node)

# These 2 replacements avoid the most copies for LLaMa.
# rms_norm replacements avoid the most copies for LLaMa.
elif at_target == torch.ops._C.fused_add_rms_norm.default:
mutated_args = {1: 'input', 2: 'residual'}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
mutated_args = {1: 'result', 2: 'residual'}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default: # noqa: E501
mutated_args = {1: 'result', 2: 'residual', 3: 'scale'}
self.defunctionalize(graph, node, mutated_args)

elif at_target in [
torch.ops._C.rms_norm.default,
torch.ops._C.rms_norm_static_fp8_quant.default
]:
mutated_args = {1: 'result'}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.rms_norm_dynamic_fp8_quant.default:
mutated_args = {1: 'result', 2: 'scale'}
self.defunctionalize(graph, node, mutated_args)

elif at_target == torch.ops._C.silu_and_mul.default:
mutated_args = {1: 'out'}
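The {index: name} maps describe where each mutated tensor sits in the auto_functionalized output tuple, so the pass can rewire users back onto the in-place buffers. An illustrative before/after for the new fused-add op (schematic only; the pass rewrites FX graph nodes, not eager calls):

# Functionalized form produced by torch.compile (pure, returns new values):
#   at = auto_functionalized(torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default,
#                            result=result, input=x, residual=residual,
#                            weight=w, scale=scale, epsilon=eps)
#   result_new, residual_new, scale_new = at[1], at[2], at[3]
#
# After de-functionalization (mutates the buffers named in mutated_args):
#   torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant(result, x, residual, w, scale, eps)
#   # users of at[1] / at[2] / at[3] now read result / residual / scale directly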
238 changes: 238 additions & 0 deletions vllm/compilation/fusion.py
@@ -16,6 +16,41 @@
logger = init_logger(__name__)


# TODO temp
@torch.library.custom_op("_C::rms_norm_dynamic_fp8_quant",
mutates_args=("result", "scale"))
def rms_norm_dynamic_fp8_quant(result: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, scale: torch.Tensor,
epsilon: float) -> None:
result_rms = torch.empty_like(input)
torch.ops._C.rms_norm(result_rms, input, weight, epsilon)
torch.ops._C.dynamic_scaled_fp8_quant(result, result_rms, scale)


@torch.library.register_fake("_C::rms_norm_dynamic_fp8_quant")
def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor, epsilon: float):
return None


@torch.library.custom_op("_C::fused_add_rms_norm_dynamic_fp8_quant",
mutates_args=("result", "residual", "scale"))
def fused_add_rms_norm_dynamic_fp8_quant(result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float) -> None:
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
torch.ops._C.dynamic_scaled_fp8_quant(result, input, scale)


@torch.library.register_fake("_C::rms_norm_dynamic_fp8_quant")
def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor, epsilon: float):
return None
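A minimal eager-mode sketch of calling the new ops registered above; shapes and dtypes follow the pattern inputs used later in this file, and it assumes the vLLM _C CUDA extension is built, since the ops decompose into rms_norm and dynamic_scaled_fp8_quant:

import torch

hidden = 4096
x = torch.randn(16, hidden, dtype=torch.bfloat16, device="cuda")
w = torch.randn(hidden, dtype=torch.bfloat16, device="cuda")
out = torch.empty(16, hidden, dtype=torch.float8_e4m3fn, device="cuda")
scale = torch.empty(1, dtype=torch.float32, device="cuda")  # written by the op

torch.ops._C.rms_norm_dynamic_fp8_quant(out, x, w, scale, 1e-6)
# out now holds the quantized RMSNorm output, scale the per-tensor scale it computed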


def empty_bf16(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")

@@ -163,6 +198,7 @@ def insert_auto_fn(self, op, kwargs):
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

QUANT_STATIC_FP8_OP = torch.ops._C.static_scaled_fp8_quant.default
QUANT_DYNAMIC_FP8_OP = torch.ops._C.dynamic_scaled_fp8_quant.default


class RMSNormQuantPattern:
@@ -312,6 +348,198 @@ def process(self):
fused_node.meta["val"] = (None, quant_tup[1], rms_tup[2])


class RMSNormDynamicFP8QuantPattern(RMSNormQuantPattern):

def register(self, pm_pass: PatternMatcherPass,
record_match: Callable[[MultiOutputMatch], bool]):

def pattern(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at1 = auto_functionalized(RMS_OP,
result=result_rms,
input=input,
weight=weight,
epsilon=self.epsilon)
at2 = auto_functionalized(QUANT_DYNAMIC_FP8_OP,
result=result,
input=at1[1],
scale=scale)

# result, scale
return at2[1], at2[2]

def replacement(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(
torch.ops._C.rms_norm_dynamic_fp8_quant.default,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon)

# result, scale
return at[1], at[2]

inputs = [
empty_fp8(5, 4), # result
empty_bf16(5, 4), # result_rms
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
empty_fp32(1, 1) # scale
]

pm.register_replacement(
pattern,
replacement,
inputs,
pm.fwd_only,
pm_pass,
extra_check=lambda m: record_match(self.Match(m)))

class Match(MultiOutputMatch):

def process(self):
# Find the nodes in the match that we need to rebind
rms_node = self.find_auto_fn(RMS_OP)
quant_node = self.find_auto_fn(QUANT_DYNAMIC_FP8_OP)

assert len(rms_node.users) == 1
assert len(quant_node.users) == 2

# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract the result and scale.
# The auto_fn node returns a tuple of (None, result, scale).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.rms_norm_dynamic_fp8_quant.default, ...) # noqa
# result_node_new = at[1]
# scale_node_new = at[2]
with self.inserting_after_match():
kwargs = self.match.kwargs.copy()

# Scalars cannot be inputs to the pattern
kwargs["epsilon"] = rms_node.kwargs["epsilon"]
del kwargs["result_rms"] # not used in the fused op

fused_node = self.insert_auto_fn(
torch.ops._C.rms_norm_dynamic_fp8_quant.default,
kwargs=kwargs)

getitem_nodes = self.insert_getitems(fused_node, (1, 2))
result_node_new, scale_node_new = getitem_nodes

# Rebind the users of match getitem nodes to use the new nodes.
# The old nodes will be removed by DCE at the end of the pass.
find_getitem(quant_node, 1).replace_all_uses_with(result_node_new)
find_getitem(quant_node, 2).replace_all_uses_with(scale_node_new)

# Finally, fix meta["val"] for de-functionalization.
# See MultiOutputMatch.process for more details.
# Result of fused node is (None, result, scale)
fused_node.meta["val"] = quant_node.meta["val"]


class FusedAddRMSNormDynamicFP8QuantPattern(RMSNormQuantPattern):

def register(self, pm_pass: PatternMatcherPass,
record_match: Callable[[MultiOutputMatch], bool]):

def pattern(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon)
at1 = auto_functionalized(QUANT_DYNAMIC_FP8_OP,
result=result,
input=at[1],
scale=scale)

# result, residual, scale
return at1[1], at[2], at1[2]

def replacement(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(
torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default,
result=result,
input=input,
residual=residual,
weight=weight,
scale=scale,
epsilon=self.epsilon)

# result, residual, scale
return at[1], at[2], at[3] # TODO confirm signature

inputs = [
empty_fp8(5, 4), # result
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
empty_fp32(1, 1) # scale
]

pm.register_replacement(
pattern,
replacement,
inputs,
pm.fwd_only,
pm_pass,
extra_check=lambda m: record_match(self.Match(m)))

class Match(MultiOutputMatch):

def process(self):
# Find the nodes in the match that we need to rebind
rms_node = self.find_auto_fn(RMS_ADD_OP)
quant_node = self.find_auto_fn(QUANT_DYNAMIC_FP8_OP)

assert len(rms_node.users) == 2
assert len(quant_node.users) == 2

# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract result, residual, and scale.
# The auto_fn node returns a tuple (None, result, residual, scale).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default, ...) # noqa
# result_node_new = at[1]
# residual_node_new = at[2]
# scale_node_new = at[3]
with self.inserting_after_match():
kwargs = self.match.kwargs.copy()

# Scalars cannot be inputs to the pattern
kwargs["epsilon"] = rms_node.kwargs["epsilon"]

fused_node = self.insert_auto_fn(
torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default,
kwargs=kwargs)

getitem_ns = self.insert_getitems(fused_node, (1, 2, 3))
result_node_new, residual_node_new, scale_node_new = getitem_ns

# Rebind the users of match getitem nodes to use the new nodes.
# The old nodes will be removed by DCE at the end of the pass.
find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new)
find_getitem(quant_node, 1).replace_all_uses_with(result_node_new)
find_getitem(quant_node, 2).replace_all_uses_with(scale_node_new)

# Finally, fix meta["val"] for de-functionalization.
# See MultiOutputMatch.process for more details.
rms_tup, quant_tup = rms_node.meta["val"], quant_node.meta["val"]
# Result of fused node is (None, result, residual, scale)
fused_node.meta["val"] = (None, quant_tup[1], rms_tup[2],
quant_tup[2])


class FusionPass(InductorPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
@@ -360,6 +588,16 @@ def __init__(self, config: CompilationConfig):
FusedAddRMSNormStaticFP8QuantPattern(epsilon).register(
self.patterns, self.record_match)

# Fuse rms_norm + dynamic_scaled_fp8_quant into
# rms_norm_dynamic_fp8_quant
RMSNormDynamicFP8QuantPattern(epsilon).register(
self.patterns, self.record_match)

# Fuse fused_add_rms_norm + dynamic_scaled_fp8_quant into
# fused_add_rms_norm_dynamic_fp8_quant
FusedAddRMSNormDynamicFP8QuantPattern(epsilon).register(
self.patterns, self.record_match)

# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()
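Taken together, the two new registrations rewrite the following unfused sequences into single fused ops (schematic, using the op signatures defined earlier in this file):

# rms_norm + dynamic quant:
#   torch.ops._C.rms_norm(result_rms, input, weight, eps)
#   torch.ops._C.dynamic_scaled_fp8_quant(result, result_rms, scale)
# becomes:
#   torch.ops._C.rms_norm_dynamic_fp8_quant(result, input, weight, scale, eps)
#
# fused_add_rms_norm + dynamic quant:
#   torch.ops._C.fused_add_rms_norm(input, residual, weight, eps)
#   torch.ops._C.dynamic_scaled_fp8_quant(result, input, scale)
# becomes:
#   torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant(result, input, residual,
#                                                     weight, scale, eps)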
