Skip to content

Commit

Permalink
Fix nested fqn discovery (pytorch#125957)
Browse files Browse the repository at this point in the history
I think I missed some fix!
Pull Request resolved: pytorch#125957
Approved by: https://github.com/sanketpurandare, https://github.com/janeyx99
  • Loading branch information
albanD authored and pytorchmergebot committed May 13, 2024
1 parent 9e1826d commit b620231
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
18 changes: 12 additions & 6 deletions test/test_module_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = Foo()
self.b = Foo()
self.b = torch.nn.ModuleDict({"nest": Foo()})
self.c = torch.nn.ModuleList([Foo()])

def forward(self, x):
return self.b(self.a(x))
x = self.c[0](x)
return self.b["nest"](self.a(x))

mod = Mod()

Expand All @@ -43,20 +45,24 @@ def forward(self, x):
self.assertEqual(
seen_fw,
[
({"Global", "Mod", "Mod.c.0"}, False),
({"Global", "Mod", "Mod.a"}, False),
({"Global", "Mod", "Mod.b"}, False),
({"Global", "Mod", "Mod.b.nest"}, False),
({"Global", "Mod", "Mod.c.0"}, False),
({"Global", "Mod", "Mod.a"}, False),
({"Global", "Mod", "Mod.b"}, False),
({"Global", "Mod", "Mod.b.nest"}, False),
],
)

self.assertEqual(
seen_bw,
[
({"Global", "Mod", "Mod.b"}, True),
({"Global", "Mod", "Mod.b.nest"}, True),
({"Global", "Mod", "Mod.a"}, True),
({"Global", "Mod", "Mod.b"}, True),
({"Global", "Mod", "Mod.c.0"}, True),
({"Global", "Mod", "Mod.b.nest"}, True),
({"Global", "Mod", "Mod.a"}, True),
({"Global", "Mod", "Mod.c.0"}, True),
],
)

Expand Down
4 changes: 3 additions & 1 deletion torch/utils/module_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def my_linear(m1, m2, bias):
def __init__(self):
self.parents = {"Global"}
self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
self._seen_modules = set()
self._seen_modules: weakref.WeakSet = weakref.WeakSet()
self._has_callback = False

def _maybe_set_engine_callback(self):
Expand Down Expand Up @@ -81,6 +81,8 @@ def _get_mod_name(self, mod):
if mod not in self._seen_modules:
for name, submod in mod.named_children():
self._known_modules[submod] = f"{mod_name}.{name}"
self._get_mod_name(submod)
self._seen_modules.add(mod)
return mod_name

def _get_append_fn(self, name, is_bw):
Expand Down

0 comments on commit b620231

Please sign in to comment.