Merge branch 'main' into jjsjann123/dynamic_dict_getitem
jjsjann123 authored Dec 1, 2024
2 parents fd1071b + fef423b commit d8e1b0e
Showing 12 changed files with 493 additions and 58 deletions.

thunder/benchmarks/benchmark_litgpt.py (3 changes: 2 additions & 1 deletion)

@@ -851,7 +851,8 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
         if tokens_per_sec:
             print(f"Tokens/s: {tokens_per_sec:.02f}")
             print(f"Tokens/s/GPU: {(tokens_per_sec / world_size):.02f}")
-        print(f"TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec'] / 1e12:.02f}")
+        if benchmark.throughput:
+            print(f"TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec'] / 1e12:.02f}")

         if benchmark.dump_memory_snapshot:
             file_name = f"{benchmark.model_name}_{benchmark.compile}_{benchmark.distributed_mode}"

thunder/core/transform_common.py (3 changes: 3 additions & 0 deletions)

@@ -404,6 +404,9 @@ def reverse_transform_state_dict_for_submodule(
     ) -> dict[str, Any]:
         return state_dict

+    def __repr__(self) -> str:
+        return f"{self.__class__.__module__}.{self.__class__.__name__}()"
+

 def order_proxies(bsyms: Sequence[BoundSymbol]) -> dict[str, int]:
     """computes a canonical ordering of proxies in the bound symbols based on the order of appearance

thunder/dynamo/compiler.py (46 changes: 45 additions & 1 deletion)

@@ -7,11 +7,13 @@
 import torch

 from thunder.core.baseutils import run_once
-from thunder.dynamo.utils import recompile_graph, remove_empty_autocast
+from thunder.core.utils import safe_zip
+from thunder.dynamo.utils import recompile_graph, remove_empty_autocast, reproducer, CompilerType
 from thunder.dynamo.splitter import _splitter

 if TYPE_CHECKING:
     from thunder.dynamo.utils import SubgraphInfo
+    from os import PathLike


 @run_once

@@ -83,3 +85,45 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor
         split_module, subgraph_info = _splitter(gm, self._thunder_jit, self._torch_compile, sample_args)
         self.subgraph_infos.append(subgraph_info)
         return split_module
+
+    def save_reproducer_to_folder(self, reproducer_folder: str | PathLike, use_pytest_benchmark: bool = False):
+        """
+        Save the reproducer script for the GraphModule executed by Thunder to the specified `reproducer_folder`.
+        Each saved script is named "graph[graph_id]_thunder_[module_id]", where:
+
+            - `graph_id` indexes the graph generated by Dynamo, which is then passed to Thunder.
+            - `module_id` indexes the submodule split by :func:`thunder.dynamo.utils._splitter`.
+
+        Args:
+            reproducer_folder (str | PathLike): The folder where the reproducer code will be written. Can be specified as an absolute or relative path.
+            use_pytest_benchmark (bool): Determines the type of script to create:
+
+                - If use_pytest_benchmark=False: Creates a reproducer script.
+                - If use_pytest_benchmark=True: Creates a benchmark script that compares the reproducer's performance with other backends, including Torch eager, torch.compile, and torch.compile with `backend="eager"`.
+        """
+        if not self.subgraph_infos:
+            raise TypeError(f"{self} doesn't seem to have been called yet.")
+
+        for graph_idx, subgraph_info in enumerate(self.subgraph_infos):
+            thunder_module_names = []
+            for node in subgraph_info.split_graph_module.graph.nodes:
+                target = node.target
+                if isinstance(target, str) and target.startswith("thunder_"):
+                    thunder_module_names.append(target)
+            original_thunder_modules = (
+                m
+                for m, compiled_m in subgraph_info.submodule_to_compiled_functions.items()
+                if compiled_m.compiler == CompilerType.THUNDER
+            )
+            example_inputs = subgraph_info.thunder_compiled_fns_example_inputs
+            for cur_module, example_input, cur_name in safe_zip(
+                original_thunder_modules, example_inputs, thunder_module_names
+            ):
+                reproducer(
+                    cur_module,
+                    self.thunder_options,
+                    example_input,
+                    reproducer_folder,
+                    f"graph{graph_idx}_{cur_name}",
+                    use_pytest_benchmark,
+                )
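
A minimal usage sketch (not part of the diff) of how this method is typically driven, assuming the surrounding class is the ThunderCompiler backend exposed as thunder.dynamo.ThunderCompiler; the model, input, and folder name below are placeholders:

    import torch
    from thunder.dynamo import ThunderCompiler

    backend = ThunderCompiler()
    model = torch.nn.Linear(8, 8)
    compiled = torch.compile(model, backend=backend)
    compiled(torch.randn(4, 8))  # the backend must have been invoked at least once
    backend.save_reproducer_to_folder("thunder_repros")  # one script per Thunder-compiled submodule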

thunder/dynamo/compiler_graph_benchmark.py (26 changes: 16 additions & 10 deletions)

@@ -103,19 +103,25 @@ def run_bench(self, gm: torch.fx.GraphModule, name: str, *sample_args):
         if self.post_graph:
             compiled_fn = self.post_graph(compiled_fn, sample_args)

-        with record_peak_allocated_memory(self.bench):
+        # This guard ensures compatibility with CPU-only PyTorch builds.
+        if torch.cuda.is_available():
+            with record_peak_allocated_memory(self.bench):
+                self.bench(compiled_fn, *sample_args)
+        else:
             self.bench(compiled_fn, *sample_args)
         # BenchmarkFixture.stats is created each time bench is called (ref: https://github.com/pybenchmark/pytest-benchmark/blob/8c9a5faa1dd178b53ab7b2a66f5364a77e903d74/src/pytest_benchmark/fixture.py#L150)
         # Adds the graph number, split module name and executor suffix to the name string
         gid_key, module_name_key, ex_key = GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS
-        self.bench.stats.name += f"-{gid_key}[{self.graph_idx+1}]-{module_name_key}[{name}]-{ex_key}[{ex_name}]"
-        assert MAX_ALLOCATED_MEMORY_KEYWORD in self.bench.extra_info
-        assert f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}" not in self.bench.extra_info
-        # NOTE: A benchmark can include multiple stats, but only one extra_info field is allowed per benchmark.
-        # Therefore, we use the current stats name as a prefix to distinguish memory usage for each stats.
-        self.bench.extra_info[f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}"] = (
-            self.bench.extra_info.pop(MAX_ALLOCATED_MEMORY_KEYWORD)
-        )
+        self.bench.stats.name += f"-{gid_key}[{self.graph_idx}]-{module_name_key}[{name}]-{ex_key}[{ex_name}]"
+
+        if torch.cuda.is_available():
+            assert MAX_ALLOCATED_MEMORY_KEYWORD in self.bench.extra_info
+            assert f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}" not in self.bench.extra_info
+            # NOTE: A benchmark can include multiple stats, but only one extra_info field is allowed per benchmark.
+            # Therefore, we use the current stats name as a prefix to distinguish memory usage for each stats.
+            self.bench.extra_info[f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}"] = (
+                self.bench.extra_info.pop(MAX_ALLOCATED_MEMORY_KEYWORD)
+            )

         # when the graph is segmented, self.bench runs multiple times, and pytest-benchmark throws an error:
         # `FixtureAlreadyUsed("Fixture can only be used once. Previously it was used in %s mode." % self._mode)`
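
For reference, a standalone sketch (not part of the diff; the helper name is made up) of the same guard pattern: peak CUDA memory is only tracked when a GPU build is available, while the benchmarked call itself always runs:

    import torch

    def bench_with_optional_memory_tracking(bench_fn, *args):
        # Track peak allocated CUDA memory only on CUDA-enabled builds; on CPU-only
        # builds just run the benchmarked callable.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
            bench_fn(*args)
            return torch.cuda.max_memory_allocated() / 2**20  # peak MiB
        bench_fn(*args)
        return None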

@@ -158,7 +164,7 @@ def has_checkpoint_node(g):
                 cur_nodes = cur_module.graph.nodes
                 # Creates random input values for the current module based on the faketensor 'example_value' of the placeholder node
                 placeholders = list(n for n in cur_nodes if n.op == "placeholder")
-                args = chain(*map(_get_example_inputs_from_placeholder, placeholders))
+                args = list(map(_get_example_inputs_from_placeholder, placeholders))
                 # Runs the benchmark on the original module with the generated random inputs
                 self.run_bench(compiled_functions_to_submodule[cur_module], target, *args)
         self.graph_idx += 1

thunder/dynamo/splitter.py (13 changes: 12 additions & 1 deletion)

@@ -1,6 +1,7 @@
 from __future__ import annotations
 from typing import TYPE_CHECKING
 import copy
+from functools import partial

 import torch
 from torch.fx.passes.split_module import split_module

@@ -16,6 +17,7 @@
     update_node_and_submodule,
     recompile_graph,
     checkpoint_converter,
+    _get_example_inputs_from_placeholder,
 )

 if TYPE_CHECKING:

@@ -124,8 +126,9 @@ def callback(node) -> int:
             return partition_cnt

         # There is a flip. Either from supported to unsupported or unsupported to supported.
+        if prev_value is not None:
+            partition_cnt += 1  # Bump the region cnt.
         prev_value = is_thunder_supported
-        partition_cnt += 1  # Bump the region cnt.

         if is_thunder_supported:
             supported_partitions.add(partition_cnt)
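
To make the flip handling concrete, a hypothetical walkthrough (not from the diff, and assuming the earlier part of callback returns early when the support status does not change) for a node sequence whose Thunder-support flags are [True, True, False, True]:

    # partition_cnt starts at 0, prev_value at None
    # node 1 (True):  prev_value is None -> no bump, region 0, prev_value = True
    # node 2 (True):  same support status -> stays in region 0
    # node 3 (False): flip True -> False  -> bump, region 1, prev_value = False
    # node 4 (True):  flip False -> True  -> bump, region 2, prev_value = True
    # Result: regions 0 and 2 are Thunder-supported and region 1 falls back; guarding the
    # bump with `prev_value is not None` keeps the very first region numbered 0.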

@@ -142,11 +145,18 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:

     # Call compile on the split region/s.
     thunder_compiled_fns = []
+    example_input_metadatas = []
     submodule_to_compiled_fns = {}
     for node in split_gm.graph.nodes:
         node_name = node.name
         if is_thunder_supported_partition(node):
             graph_module = getattr(split_gm, node.name)
+            # Record the input tensor metadata of the current module based on the faketensor 'example_value' of the placeholder node
+            placeholders = list(n for n in graph_module.graph.nodes if n.op == "placeholder")
+            example_input_metadata = map(
+                partial(_get_example_inputs_from_placeholder, only_metadata=True), placeholders
+            )
+            example_input_metadatas.append(list(example_input_metadata))
             # Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators
             checkpoint_converter(split_gm, graph_module)
             jit_fn = thunder_jit(graph_module)
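
For intuition, a simplified sketch (not the repository's _get_example_inputs_from_placeholder, whose only_metadata behavior is assumed here) of how tensor metadata can be read from a Dynamo placeholder's FakeTensor without materializing real data:

    import torch

    def placeholder_metadata(node: torch.fx.Node):
        # Dynamo records a FakeTensor (or SymInt/scalar) under the 'example_value' key of node.meta.
        example = node.meta.get("example_value")
        if isinstance(example, torch.Tensor):
            return {"shape": tuple(example.shape), "dtype": example.dtype, "device": example.device}
        return example  # non-tensor placeholders (e.g. SymInt) pass through unchanged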

@@ -176,6 +186,7 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
         original_split_gm,
         split_gm,
         thunder_compiled_fns,
+        example_input_metadatas,
         submodule_to_compiled_fns,
         split_reasons,
     )
