# Owner(s): ["module: unknown"] from copy import copy import torch from torch.testing._internal.common_utils import run_tests, TestCase, xfailIfTorchDynamo from torch.utils.module_tracker import ModuleTracker class TestModuleTracker(TestCase): # "https://github.com/pytorch/pytorch/issues/127112 @xfailIfTorchDynamo def test_module_hierarchy(self): seen_fw = [] seen_bw = [] class Foo(torch.nn.Module): def forward(self, x): x = x["a"].relu_() seen_fw.append((copy(tracker.parents), tracker.is_bw)) x.register_hook( lambda grad: seen_bw.append((copy(tracker.parents), tracker.is_bw)) ) return {"a": torch.mm(x, x)} class Mod(torch.nn.Module): def __init__(self): super().__init__() self.a = Foo() self.b = torch.nn.ModuleDict({"nest": Foo()}) self.c = torch.nn.ModuleList([Foo()]) def forward(self, x): x = self.c[0](x) return self.b["nest"](self.a(x)) mod = Mod() with ModuleTracker() as tracker: mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[ "a" ].sum().backward() mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[ "a" ].sum().backward() self.assertEqual( seen_fw, [ ({"Global", "Mod", "Mod.c.0"}, False), ({"Global", "Mod", "Mod.a"}, False), ({"Global", "Mod", "Mod.b.nest"}, False), ({"Global", "Mod", "Mod.c.0"}, False), ({"Global", "Mod", "Mod.a"}, False), ({"Global", "Mod", "Mod.b.nest"}, False), ], ) self.assertEqual( seen_bw, [ ({"Global", "Mod", "Mod.b.nest"}, True), ({"Global", "Mod", "Mod.a"}, True), ({"Global", "Mod", "Mod.c.0"}, True), ({"Global", "Mod", "Mod.b.nest"}, True), ({"Global", "Mod", "Mod.a"}, True), ({"Global", "Mod", "Mod.c.0"}, True), ], ) def test_bw_detection(self): mod = torch.nn.Linear(2, 2) with ModuleTracker() as tracker: mod(torch.rand(2, requires_grad=True)).sum().backward() self.assertFalse(tracker.is_bw) self.assertEqual(tracker.parents, {"Global"}) if __name__ == "__main__": run_tests()