From b6202313781c0445a66950663c4771ba1209eaa9 Mon Sep 17 00:00:00 2001 From: albanD Date: Mon, 13 May 2024 18:24:51 +0000 Subject: [PATCH] Fix nested fqn discovery (#125957) I think I missed some fix! Pull Request resolved: https://github.com/pytorch/pytorch/pull/125957 Approved by: https://github.com/sanketpurandare, https://github.com/janeyx99 --- test/test_module_tracker.py | 18 ++++++++++++------ torch/utils/module_tracker.py | 4 +++- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/test/test_module_tracker.py b/test/test_module_tracker.py index e465b128edbea..09b550354282d 100644 --- a/test/test_module_tracker.py +++ b/test/test_module_tracker.py @@ -25,10 +25,12 @@ class Mod(torch.nn.Module): def __init__(self): super().__init__() self.a = Foo() - self.b = Foo() + self.b = torch.nn.ModuleDict({"nest": Foo()}) + self.c = torch.nn.ModuleList([Foo()]) def forward(self, x): - return self.b(self.a(x)) + x = self.c[0](x) + return self.b["nest"](self.a(x)) mod = Mod() @@ -43,20 +45,24 @@ def forward(self, x): self.assertEqual( seen_fw, [ + ({"Global", "Mod", "Mod.c.0"}, False), ({"Global", "Mod", "Mod.a"}, False), - ({"Global", "Mod", "Mod.b"}, False), + ({"Global", "Mod", "Mod.b.nest"}, False), + ({"Global", "Mod", "Mod.c.0"}, False), ({"Global", "Mod", "Mod.a"}, False), - ({"Global", "Mod", "Mod.b"}, False), + ({"Global", "Mod", "Mod.b.nest"}, False), ], ) self.assertEqual( seen_bw, [ - ({"Global", "Mod", "Mod.b"}, True), + ({"Global", "Mod", "Mod.b.nest"}, True), ({"Global", "Mod", "Mod.a"}, True), - ({"Global", "Mod", "Mod.b"}, True), + ({"Global", "Mod", "Mod.c.0"}, True), + ({"Global", "Mod", "Mod.b.nest"}, True), ({"Global", "Mod", "Mod.a"}, True), + ({"Global", "Mod", "Mod.c.0"}, True), ], ) diff --git a/torch/utils/module_tracker.py b/torch/utils/module_tracker.py index 078effe99aefb..b79d1432bb1bb 100644 --- a/torch/utils/module_tracker.py +++ b/torch/utils/module_tracker.py @@ -52,7 +52,7 @@ def my_linear(m1, m2, bias): def __init__(self): self.parents = {"Global"} self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() - self._seen_modules = set() + self._seen_modules: weakref.WeakSet = weakref.WeakSet() self._has_callback = False def _maybe_set_engine_callback(self): @@ -81,6 +81,8 @@ def _get_mod_name(self, mod): if mod not in self._seen_modules: for name, submod in mod.named_children(): self._known_modules[submod] = f"{mod_name}.{name}" + self._get_mod_name(submod) + self._seen_modules.add(mod) return mod_name def _get_append_fn(self, name, is_bw):