Skip to content

Commit

Permalink
Switch OSS dashboard to use aoti_compile_and_package (#139597)
Browse files Browse the repository at this point in the history
Summary:
Reland pytorch/pytorch#139154

X-link: pytorch/pytorch#139597
Approved by: https://github.com/angelayi

Reviewed By: ZainRizvi

Differential Revision: D65455707

Pulled By: desertfire

fbshipit-source-id: 691882e606754fc04cb826a14bdfe94cb465ece8
  • Loading branch information
desertfire authored and facebook-github-bot committed Nov 5, 2024
1 parent 86a366e commit 4a42e06
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions userbenchmark/dynamo/dynamobench/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,9 +987,7 @@ def maybe_mark_profile(*args, **kwargs):

with maybe_profile(args.export_profiler_trace) as p:
if args.export_aot_inductor:
frozen_model_iter_fn = export_aot_inductor(
model, example_inputs, args.devices[0]
)
frozen_model_iter_fn = export_aot_inductor(model, example_inputs)
else:
frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)

Expand Down Expand Up @@ -1487,7 +1485,7 @@ class AOTInductorModelCache:
cache = {}

@classmethod
def load(cls, model, example_inputs, device):
def load(cls, model, example_inputs):
import torch._inductor
import torch.export._trace
from torch.export.dynamic_shapes import _tree_map_with_path
Expand Down Expand Up @@ -1515,18 +1513,19 @@ def load(cls, model, example_inputs, device):
_produce_dynamic_shapes_for_export, combined_args
)

gm = torch.export._trace._export(
ep = torch.export.export(
model,
example_args,
example_kwargs,
dynamic_shapes=dynamic_shapes,
pre_dispatch=True,
strict=False,
).module()
)
with torch.no_grad():
so_path = torch._inductor.aot_compile(gm, example_args, example_kwargs) # type: ignore[arg-type]
package_path = torch._inductor.aoti_compile_and_package(
ep, example_args, example_kwargs
) # type: ignore[arg-type]

cls.cache[key] = torch._export.aot_load(so_path, device)
cls.cache[key] = torch._inductor.aoti_load_package(package_path)

return cls.cache[key]

Expand Down Expand Up @@ -1554,8 +1553,8 @@ def opt_export(_, example_inputs):
return opt_export


def export_aot_inductor(model, example_inputs, device):
optimized = AOTInductorModelCache.load(model, example_inputs, device)
def export_aot_inductor(model, example_inputs):
optimized = AOTInductorModelCache.load(model, example_inputs)

def opt_aot_inductor(_, example_inputs, collect_outputs=False):
example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
Expand Down Expand Up @@ -4585,9 +4584,7 @@ def run(runner, args, original_dir=None):
elif args.backend or args.export_aot_inductor:
if args.export_aot_inductor:
assert not args.training, "AOTInductor only supports inference"
optimize_ctx = functools.partial(
export_aot_inductor, device=args.devices[0]
)
optimize_ctx = functools.partial(export_aot_inductor)

# AOTInductor doesn't support control flow yet
runner.skip_models.update(runner.skip_models_due_to_control_flow)
Expand Down

0 comments on commit 4a42e06

Please sign in to comment.