Skip to content

Commit

Permalink
improve fusion stability (#899)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Aug 1, 2024
1 parent 9ce123b commit fa307a5
Show file tree
Hide file tree
Showing 12 changed files with 141 additions and 25 deletions.
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ looseversion ==1.3.0
lightning-utilities >=0.7.0
numpy >=1.23.0,<2 # not yet ready for numpy 2
igraph >=0.10.4
optree >=0.9.2
optree >=0.11.0
opt_einsum >= 3.3.0
mpmath <1.4.0 # todo: teporarl pin for `NameError: name '_C' is not defined`
2 changes: 2 additions & 0 deletions thunder/core/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def tree_flatten(args, namespace=""):
# while generating the split functions.
tree_map = partial(optree.tree_map, none_is_leaf=True, namespace=OPTREE_NAMESPACE)

tree_iter = partial(optree.tree_iter, none_is_leaf=True, namespace=OPTREE_NAMESPACE)


def tree_unflatten(values, spec):
return optree.tree_unflatten(spec, values)
Expand Down
9 changes: 6 additions & 3 deletions thunder/core/rematerialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from thunder.core.pytree import tree_flatten, tree_unflatten
from thunder.core.symbol import has_tags
from thunder.core.trace import from_trace, TraceCtx, TraceProvenance
from thunder.core.transform_common import dce
from thunder.core.transform_common import dce, order_proxies
from thunder.executors.passes import update_fusion_call_ctx


Expand Down Expand Up @@ -126,7 +126,8 @@ def apply_rematerialization_for_producer(
all_produced_vars = tuple(chain.from_iterable((y for y in x.flat_proxy_outs) for x in producer.subsymbols))
# Choose the new producer's output from all the produced variables.
new_producer_output = tuple(x for x in all_produced_vars if x.name in new_producer_output_names)
new_producer_output = tuple(sorted(new_producer_output, key=lambda x: x.name))
proxy_order = order_proxies(producer.subsymbols)
new_producer_output = tuple(sorted(new_producer_output, key=lambda p: proxy_order[p.name]))
new_producer = replace(producer, output=new_producer_output)
return new_producer

Expand Down Expand Up @@ -179,7 +180,9 @@ def apply_rematerialization_for_consumer(
new_consumer_args += tuple(
x for x in producer.args if x.name in all_args and x.name not in (x.name for x in new_consumer_args)
)
new_consumer_args = tuple(sorted(new_consumer_args, key=lambda x: x.name))

proxy_order = order_proxies(new_subsymbols)
new_consumer_args = tuple(sorted(new_consumer_args, key=lambda x: proxy_order[x.name]))
new_consumer = replace(consumer, args=new_consumer_args, subsymbols=new_subsymbols)
return new_consumer

Expand Down
60 changes: 58 additions & 2 deletions thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from collections import defaultdict
from collections.abc import Sequence
from collections import defaultdict
from itertools import filterfalse
from itertools import filterfalse, chain
from functools import partial

import thunder
import thunder.core.prims as prims
from thunder.core.baseutils import BoundSymbolInterface
from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy, unvariableify
from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten
from thunder.core.pytree import tree_flatten, tree_iter, tree_map, tree_unflatten
from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, has_tags
from thunder.core.trace import from_trace, TraceProvenance, TraceCtx as Trace, tracectx
from thunder.core.utils import ProxyDict, producers, check, consumers
Expand Down Expand Up @@ -865,3 +865,59 @@ def computation(x):

functionalized_computation_trace.bound_symbols = functionalized_bsyms
return [intermediate_trace, functionalized_computation_trace]


def order_proxies(bsyms: Sequence[BoundSymbol]) -> dict[str, int]:
"""computes a canonical ordering of proxies in the bound symbols based on the order of appearance
note that it would not cover unused inputs when applied to traces.bound_symbols
"""
counter = 0
proxy_order: dict[str, int] = {} # names to order

def process_bound_symbols(bound_symbols):
nonlocal counter
for bsym in bound_symbols:
if len(bsym.subsymbols) > 0:
process_bound_symbols(bsym.subsymbols)
for p in tree_iter((bsym.args, bsym.kwargs, bsym.output)): # should kwargs be sorted by name?
if isinstance(p, thunder.Proxy) and p.name not in proxy_order:
counter += 1
proxy_order[p.name] = counter

process_bound_symbols(bsyms)

return proxy_order


def canonicalize_proxies(bsyms: Sequence[BoundSymbol]) -> Sequence[BoundSymbol]:
output = []
counter = 0

proxymap: dict[str, thunder.Proxy] = {}

def map_proxy(p):
nonlocal counter
if isinstance(p, thunder.Proxy):
if p.name in proxymap:
return proxymap[p.name]
np = p.replace(name=f"p{counter}")
counter += 1
proxymap[p.name] = np
return np
return p

def process_bound_symbols(src_bound_symbols, target_bound_symbols):
for bsym in src_bound_symbols:
new_subsymbols = []
if len(bsym.subsymbols) > 0:
process_bound_symbols(bsym.subsymbols, new_subsymbols)
new_args = tree_map(map_proxy, bsym.args)
new_kwargs = tree_map(map_proxy, bsym.kwargs) # should this be sorted by key word?
new_output = tree_map(map_proxy, bsym.output)
new_bsym = bsym.from_bsym(output=new_output, args=new_args, kwargs=new_kwargs, subsymbols=new_subsymbols)
target_bound_symbols.append(new_bsym)

with thunder.core.trace.tracectx(thunder.TraceCtx()):
process_bound_symbols(bsyms, output)

return output
11 changes: 11 additions & 0 deletions thunder/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,17 @@ def difference(self, other: "_OrderedSet") -> Self:
def add(self, x: T | T1):
self.d[self.canonicalize(x)] = None

def discard(self, x: T | T1):
c = self.canonicalize(x)
if c in self.d:
del self.d[c]

def issubset(self, other):
return all((e in other) for e in self)

def union(self, *others: "Sequence[_OrderedSet]") -> Self:
return self.__class__(itertools.chain(self, *others))

def update(self, x: Iterable[T | T1]) -> None:
for i in x:
self.d.setdefault(self.canonicalize(i), None)
Expand Down
4 changes: 2 additions & 2 deletions thunder/executors/cudagraphex.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ def __init__(self, name: Hashable):
super().__init__(name, version=torch.version.cuda)

def fuse(self, region: Region, fusion_counter: int, num_static_inputs: None | int = None) -> BoundSymbol:
inputs = [unvariableify(inp) for inp in sorted(region.inputs, key=lambda var: var.proxy.name)]
outputs = [unvariableify(out) for out in sorted(region.outputs, key=lambda var: var.proxy.name)]
inputs = [unvariableify(inp) for inp in region.inputs]
outputs = [unvariableify(out) for out in region.outputs]

fusion_name = f"CUDAGraph{fusion_counter}"
fusion_callable: Callable = make_callable(f"{fusion_name}_fn", region.bound_symbols, inputs, outputs)
Expand Down
4 changes: 2 additions & 2 deletions thunder/executors/data_dependent_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def __init__(self, ID: int, group_bsyms: list[BoundSymbol], group_indices: list[
self.stop = stop
self.group_bsyms = group_bsyms
self.group_indices = group_indices
self.parents: set[Node] = set()
self.children: set[Node] = set()
self.parents: utils.OrderedSet[Node] = utils.OrderedSet()
self.children: utils.OrderedSet[Node] = utils.OrderedSet()

def __repr__(self) -> str:
s = f"node ID {self.ID} : "
Expand Down
7 changes: 2 additions & 5 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,11 +611,8 @@ def _dce_bsyms(self, input_list, output, bsyms: list[BoundSymbol]) -> list[Bound
return list(filter(lambda x: x.sym != prims.python_return, trace.bound_symbols))

def fuse(self, region: Region, fusion_counter: int) -> BoundSymbol:
def keyfn(x: Variable) -> str:
return x.proxy.name

sorted_unique_inputs: list[Proxy] = list(unvariableify(x) for x in sorted(region.inputs, key=keyfn))
sorted_unique_outputs: list[Proxy] = list(unvariableify(x) for x in sorted(region.outputs, key=keyfn))
sorted_unique_inputs: list[Proxy] = [unvariableify(x) for x in region.inputs]
sorted_unique_outputs: list[Proxy] = [unvariableify(x) for x in region.outputs]

flattened_bsyms: list[BoundSymbol] = []
for bsym in region.bound_symbols:
Expand Down
4 changes: 2 additions & 2 deletions thunder/executors/torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ def fuse(self, region: Region, fusion_counter: int) -> BoundSymbol:
def keyfn(x: Variable) -> str:
return x.proxy.name

sorted_unique_inputs: list[Proxy] = list(unvariableify(x) for x in sorted(region.inputs, key=keyfn))
sorted_unique_outputs: list[Proxy] = list(unvariableify(x) for x in sorted(region.outputs, key=keyfn))
sorted_unique_inputs: list[Proxy] = [unvariableify(x) for x in region.inputs]
sorted_unique_outputs: list[Proxy] = [unvariableify(x) for x in region.outputs]

compiled: Callable = make_compiled(region.bound_symbols, sorted_unique_inputs, sorted_unique_outputs)

Expand Down
13 changes: 9 additions & 4 deletions thunder/executors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten
from thunder.core.proxies import Variable, variableify, Proxy, unvariableify
from thunder.core.prims import PrimIDs
from thunder.core.transform_common import order_proxies

# TODO Make these tags
comment_symbols = {
Expand Down Expand Up @@ -51,25 +52,29 @@ def __init__(self, producers, consumers, bound_symbols: list[BoundSymbol]):
# Updates what this region consumes, skipping symbols that never consume anything
consumes.update(variableify(x) for x in bsym.flat_args if isinstance(x, Proxy))

self.inputs = set()
self.outputs = set()
inputs = set()
outputs = set()

# Inputs are things which this consumes which are produced before it
for x in consumes:
x = unvariableify(x)

if producers[x] not in self.bound_symbols:
self.inputs.add(variableify(x))
inputs.add(variableify(x))

# Outputs are things this produces that are consumed after it
for x in produces:
x = unvariableify(x)
consumed_by = consumers.get(x, ())
for bsym in consumed_by:
if bsym not in self.bound_symbols:
self.outputs.add(variableify(x))
outputs.add(variableify(x))
break

proxy_order = order_proxies(self.bound_symbols)
self.inputs = utils.OrderedSet(sorted(inputs, key=lambda p: proxy_order[p.proxy.name]))
self.outputs = utils.OrderedSet(sorted(outputs, key=lambda p: proxy_order[p.proxy.name]))

def __repr__(self) -> str:
s = f"[Region:"

Expand Down
42 changes: 42 additions & 0 deletions thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import partial, reduce
from itertools import product
import dataclasses
import re

import pytest
import torch
Expand Down Expand Up @@ -3060,3 +3061,44 @@ def fn(tensors):
actual.backward(cotangent)

torch.testing.assert_close(tuple(t.grad for t in tensors), tuple(t.grad for t in tensors_jit))


@requiresCUDA
def test_bound_symbol_sort_stability():
class LlamaMLPLike(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc_1 = torch.nn.Linear(32, 32)
self.fc_2 = torch.nn.Linear(32, 32)
self.proj = torch.nn.Linear(32, 32)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x_fc_1 = self.fc_1(x)
x_fc_2 = self.fc_2(x)
x = torch.nn.functional.silu(x_fc_1) * x_fc_2
return self.proj(x)

with torch.device("cuda"):
mlp = torch.nn.Sequential(*[LlamaMLPLike() for _ in range(16)]).requires_grad_(False)
j = thunder.jit(mlp)
j(torch.randn(32, 32, device="cuda"))
lt = thunder.last_traces(j)[-1]
assert all(
(i % 2 + 1 == i_2)
for i, i_2 in enumerate(
[
int(s.args[1].name.split("_")[-2])
for s in lt.bound_symbols
if s.sym.name == "linear" and "fc" in s.args[1].name
]
)
)

fusions = examine.get_fusion_symbols(lt)

no_number = partial(re.sub, r"nvFusion\d+", "nvFusion")
fusions = [no_number(str(thunder.core.transform_common.canonicalize_proxies([f])[0])) for f in fusions]

f0 = fusions[0]
for f in fusions[1:]:
assert f0 == f
8 changes: 4 additions & 4 deletions thunder/tests/test_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ def foo(a, x):
assert len(fusions) == 2

# Verifies that the nvFusion inputs and outputs are updated properly
t0 = fusions[0].output
assert fusions[1].args[0].name == "t0"
assert t0[0].name == "t0"
t0 = fusions[0].output[0]
assert fusions[1].args[2].name == "t0"
assert t0.name == "t0"
assert extrace.output[0].name == "t0"
assert len(fusions[0].subsymbols) == 3

Expand Down Expand Up @@ -316,7 +316,7 @@ def func(w, x, y, z):
assert len(fusion_bsyms) == 1
nvf_0 = fusion_bsyms[0]

assert [t.name for t in tree_flatten(nvf_0.args)[0]] == ["t0", "w", "z"]
assert [t.name for t in tree_flatten(nvf_0.args)[0]] == ["t0", "z", "w"]
assert len(nvf_0.subsymbols) == 7
assert [t.name for t in tree_flatten(nvf_0.output)[0]] == ["t13"]

Expand Down

0 comments on commit fa307a5

Please sign in to comment.