Skip to content

Commit

Permalink
More descriptive graph diagram names in svg (pytorch#106146)
Browse files Browse the repository at this point in the history
  • Loading branch information
eellison authored and pytorchmergebot committed Jul 28, 2023
1 parent 5237ed5 commit 8f4d8b3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
2 changes: 1 addition & 1 deletion torch/_inductor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ class trace:
output_code = True

# SVG figure showing post-fusion graph
graph_diagram = False
graph_diagram = os.environ.get("INDUCTOR_POST_FUSION_SVG", "0") == "1"

# Store cProfile (see snakeviz to view)
compile_profile = False
Expand Down
8 changes: 6 additions & 2 deletions torch/_inductor/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def func1(*args):

FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"])

func_dict = {s: get_fake_func(s) for s in ["extern", "nop", "compute", "fused"]}
buf_to_fx_node = {}
graph = torch.fx.Graph()
first_node = None
Expand All @@ -122,7 +121,12 @@ def func1(*args):
group = snode.group
else:
raise RuntimeError("Unknown node type")
node_func = func_dict[node_type]

fused_name = torch._inductor.utils.get_fused_kernel_name(
snode.get_nodes(), "original_aten"
)
func_name = f"{node_type}: {fused_name}"
node_func = get_fake_func(func_name)
fx_node = graph.call_function(node_func, args=(), kwargs=None)

def in_output(snode):
Expand Down

0 comments on commit 8f4d8b3

Please sign in to comment.