diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py
index 6dde69effff..f510fb87522 100644
--- a/test/dynamo/test_modules.py
+++ b/test/dynamo/test_modules.py
@@ -595,6 +595,57 @@ class ModuleAttributePrecedence(ModuleAttributePrecedenceBase):
         return self.activation(self.linear(self.initializer + x)) * self.scale
 
 
+class ModuleForwardHasGraphBreak(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layer1 = BasicModule()
+        self.layer2 = BasicModule()
+        self.layer3 = torch.nn.Sequential(BasicModule(), BasicModule())
+        self.layer4 = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(10, 10),
+                torch.nn.ReLU(),
+                torch.nn.Linear(10, 10),
+                torch.nn.ReLU(),
+            ]
+        )
+        self.layer5 = torch.nn.ModuleDict(
+            {
+                "0": torch.nn.Linear(10, 10),
+            }
+        )
+        self.scale = torch.randn(1, 10)
+
+    def forward(self, x):
+        """
+        This is used to test if the results of functions like `named_parameters`
+        can be reconstructed correctly after graph break.
+
+        https://github.com/pytorch/torchdynamo/issues/1931
+        """
+        x = self.layer1(x)
+        params1 = dict(self.named_parameters())
+        params2 = list(self.parameters())
+        buffers1 = dict(self.named_buffers())
+        buffers2 = list(self.buffers())
+        modules1 = dict(self.named_modules())
+        modules2 = list(self.modules())
+        torch._dynamo.graph_break()
+        y = modules2
+        y = modules1
+        y = buffers2
+        y = buffers1
+        y = params2
+        y = params1
+        x = (
+            self.layer2(x)
+            + y["layer3.1.linear1.weight"]
+            + y["layer4.2.weight"]
+            + y["layer5.0.weight"]
+        )
+        return x * self.scale
+
+
 def make_test(fn, expected_ops=None):
     def test_fn(self):
         return torch._dynamo.testing.standard_test(
@@ -646,6 +697,14 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
     test_module_name_string = make_test(ModuleNameString())
     test_module_attribute_precedence = make_test(ModuleAttributePrecedence())
 
+    def test_module_forward_has_graph_break(self):
+        m = ModuleForwardHasGraphBreak()
+        x = torch.rand([10, 10])
+        ref = m(x)
+        opt_m = torch._dynamo.optimize("eager")(m)
+        res = opt_m(x)
+        self.assertTrue(torch.allclose(ref, res))
+
     def test_unsupportedmethod(self):
         m = UnsupportedMethodCall()
         i = torch.randn(10)
diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py
index 7dbd0ba331f..454daae1d1f 100644
--- a/torch/_dynamo/variables/nn_module.py
+++ b/torch/_dynamo/variables/nn_module.py
@@ -1,7 +1,6 @@
 import functools
 import inspect
 import itertools
-import re
 import types
 from contextlib import contextmanager
 from typing import Dict, List
@@ -283,18 +282,15 @@ class NNModuleVariable(VariableTracker):
             bound_args = bound_args.arguments
             return {k: bound_args[k] for k in names}
 
-        def wrap_values(items, getsource=AttrSource):
+        def wrap_values(items):
             result = []
             for name, submod in items:
-                # layer.0.foo => layer[0].foo
-                name = re.sub(r"[.]([0-9]+)([.]|$)", r"[\1]\2", name)
-                src = NNModuleSource(getsource(self.source, name))
                 result.append(
                     tx.output.register_attr_or_module(
                         submod,
                         key,
                         name,
-                        source=src,
+                        source=NNModuleSource(gen_source(self.source, name)),
                         **options,
                     )
                 )
@@ -308,12 +304,21 @@ class NNModuleVariable(VariableTracker):
                         obj,
                         key,
                         name,
-                        source=NNModuleSource(AttrSource(self.source, name)),
+                        source=NNModuleSource(gen_source(self.source, name)),
                         **options,
                     ),
                 ]
             )
 
+        def gen_source(source, name):
+            name_split = name.split(".")
+            if name_split[0] == "":
+                return source
+            while len(name_split) > 0:
+                x = name_split.pop(0)
+                source = AttrSource(source, x)
+            return source
+
         if name == "children":
             assert not (args or kwargs)
             return wrap_values(module.named_children())
@@ -344,7 +349,7 @@ class NNModuleVariable(VariableTracker):
             return wrap_values(module.named_parameters(**get_kwargs("recurse")))
         elif name == "values":
             assert not (args or kwargs)
-            return wrap_values(module.items(), GetItemSource)
+            return wrap_values(module.items())
         elif name == "items":
             assert not (args or kwargs)
             result = []