Skip to content

Commit

Permalink
use print_readable for submodules (#74)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao authored Dec 7, 2024
1 parent b9fa7a5 commit 7ffbe31
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_decompile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ on:
jobs:
build:

runs-on: ubuntu-latest
runs-on: ubuntu-22.04 # ubuntu-latest does not support python 3.7
strategy:
fail-fast: false
matrix:
Expand Down
20 changes: 16 additions & 4 deletions depyf/explain/patched_lazy_format_graph_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@ def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs):
# update file path
filepath = inspect.getsourcefile(fn)
# try to use verbose code with type and shape annotations
src = "from __future__ import annotations\n" + \
gm._graph.python_code(root_module="self", verbose=True).src
use_gm = True

# use `print_readable` because it can include submodules
src = "from __future__ import annotations\nimport torch\n" + \
gm.print_readable(print_output=False)
src = src.replace("<lambda>", "GraphModule")
try:
compile(src, "noname", "exec")
except Exception as e:
Expand All @@ -38,13 +42,21 @@ def patched_lazy_format_graph_code(name, gm, maybe_id=None, **kwargs):
commented_src += "".join(["# " + line +
"\n" for line in src.splitlines()])
src = simple_code + commented_src
use_gm = False
if filepath is not None:
new_filepath = write_code_to_file_template(
src, os.path.dirname(filepath) + "/" + file_name + "." + "%s" + ".py")
scope = fn.__globals__
exec(compile(src, filename=new_filepath, mode="exec"), scope)
fn.__code__ = scope[fn.__name__].__code__
del scope[fn.__name__]
if use_gm:
import torch
classes = [v for v in scope.values() if isinstance(v, type) and issubclass(v, torch.nn.Module)]
assert len(classes) == 1
module_class = classes[0]
fn.__code__ = getattr(module_class, fn.__name__).__code__
else:
fn.__code__ = scope[fn.__name__].__code__
del scope[fn.__name__]

# =========================================
# original code of `lazy_format_graph_code`
Expand Down

0 comments on commit 7ffbe31

Please sign in to comment.