[Dynamo] Fix source/reconstruction bugs in NNModule named_* calls (#89729)

Fixes https://github.com/pytorch/torchdynamo/issues/1931

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89729
Approved by: https://github.com/ezyang
Yanbo Liang 2022-11-30 06:05:44 +00:00 committed by PyTorch MergeBot
parent 447283752c
commit d88b555577
2 changed files with 72 additions and 8 deletions
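
Context: Dynamo emits "source" expressions so that Python values which are live across a graph break can be reconstructed when execution resumes. Before this change, values produced by NNModule `named_*` calls could get invalid sources. A minimal sketch of the failure mode this commit guards against (illustrative module and shapes, not the exact repro from the linked issue):

    import torch
    import torch._dynamo

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))

        def forward(self, x):
            # The dict below is live across the graph break, so Dynamo must
            # reconstruct it (and the sources of its values) correctly.
            params = dict(self.named_parameters())  # keys like "layers.0.weight"
            torch._dynamo.graph_break()
            return x @ params["layers.0.weight"]

    m = M()
    x = torch.randn(4, 4)
    assert torch.allclose(m(x), torch._dynamo.optimize("eager")(m)(x))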

@@ -595,6 +595,57 @@ class ModuleAttributePrecedence(ModuleAttributePrecedenceBase):
         return self.activation(self.linear(self.initializer + x)) * self.scale
 
 
+class ModuleForwardHasGraphBreak(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layer1 = BasicModule()
+        self.layer2 = BasicModule()
+        self.layer3 = torch.nn.Sequential(BasicModule(), BasicModule())
+        self.layer4 = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(10, 10),
+                torch.nn.ReLU(),
+                torch.nn.Linear(10, 10),
+                torch.nn.ReLU(),
+            ]
+        )
+        self.layer5 = torch.nn.ModuleDict(
+            {
+                "0": torch.nn.Linear(10, 10),
+            }
+        )
+        self.scale = torch.randn(1, 10)
+
+    def forward(self, x):
+        """
+        This is used to test whether the results of functions like `named_parameters`
+        can be reconstructed correctly after a graph break.
+        https://github.com/pytorch/torchdynamo/issues/1931
+        """
+        x = self.layer1(x)
+        params1 = dict(self.named_parameters())
+        params2 = list(self.parameters())
+        buffers1 = dict(self.named_buffers())
+        buffers2 = list(self.buffers())
+        modules1 = dict(self.named_modules())
+        modules2 = list(self.modules())
+        torch._dynamo.graph_break()
+        y = modules2
+        y = modules1
+        y = buffers2
+        y = buffers1
+        y = params2
+        y = params1
+        x = (
+            self.layer2(x)
+            + y["layer3.1.linear1.weight"]
+            + y["layer4.2.weight"]
+            + y["layer5.0.weight"]
+        )
+        return x * self.scale
+
+
 def make_test(fn, expected_ops=None):
     def test_fn(self):
         return torch._dynamo.testing.standard_test(
@@ -646,6 +697,14 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
     test_module_name_string = make_test(ModuleNameString())
     test_module_attribute_precedence = make_test(ModuleAttributePrecedence())
 
+    def test_module_forward_has_graph_break(self):
+        m = ModuleForwardHasGraphBreak()
+        x = torch.rand([10, 10])
+        ref = m(x)
+        opt_m = torch._dynamo.optimize("eager")(m)
+        res = opt_m(x)
+        self.assertTrue(torch.allclose(ref, res))
+
     def test_unsupportedmethod(self):
         m = UnsupportedMethodCall()
         i = torch.randn(10)
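
The dict keys exercised by the new test mix plain attribute names, integer indices from Sequential/ModuleList, and string keys from ModuleDict, which is exactly the shape of name the fixed source generation must handle. A quick sanity check (illustrative; it assumes BasicModule exposes its Linear as `linear1`, as the lookups in forward() imply):

    m = ModuleForwardHasGraphBreak()
    names = set(dict(m.named_parameters()).keys())
    assert "layer3.1.linear1.weight" in names  # Sequential index + nested attribute
    assert "layer4.2.weight" in names          # ModuleList index
    assert "layer5.0.weight" in names          # ModuleDict string key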

@@ -1,7 +1,6 @@
 import functools
 import inspect
 import itertools
-import re
 import types
 from contextlib import contextmanager
 from typing import Dict, List
@@ -283,18 +282,15 @@ class NNModuleVariable(VariableTracker):
             bound_args = bound_args.arguments
             return {k: bound_args[k] for k in names}
 
-        def wrap_values(items, getsource=AttrSource):
+        def wrap_values(items):
             result = []
             for name, submod in items:
-                # layer.0.foo => layer[0].foo
-                name = re.sub(r"[.]([0-9]+)([.]|$)", r"[\1]\2", name)
-                src = NNModuleSource(getsource(self.source, name))
                 result.append(
                     tx.output.register_attr_or_module(
                         submod,
                         key,
                         name,
-                        source=src,
+                        source=NNModuleSource(gen_source(self.source, name)),
                         **options,
                     )
                 )
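
The removed code rewrote names like "layer.0.foo" into "layer[0].foo" so that an item-style source (GetItemSource) could be used. The replacement, gen_source in the next hunk, instead chains plain attribute accesses, which is sound because nn.Module.__getattr__ resolves digit-named children of ModuleList/ModuleDict as attributes. A quick standalone illustration of that behavior:

    import torch

    ml = torch.nn.ModuleList([torch.nn.Linear(2, 2), torch.nn.ReLU()])
    md = torch.nn.ModuleDict({"0": torch.nn.Linear(2, 2)})
    assert getattr(ml, "0") is ml[0]    # integer-indexed child is attribute-reachable
    assert getattr(md, "0") is md["0"]  # ModuleDict keys resolve the same way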
@@ -308,12 +304,21 @@
                         obj,
                         key,
                         name,
-                        source=NNModuleSource(AttrSource(self.source, name)),
+                        source=NNModuleSource(gen_source(self.source, name)),
                         **options,
                     ),
                 ]
             )
 
+        def gen_source(source, name):
+            name_split = name.split(".")
+            if name_split[0] == "":
+                return source
+            while len(name_split) > 0:
+                x = name_split.pop(0)
+                source = AttrSource(source, x)
+            return source
+
         if name == "children":
             assert not (args or kwargs)
             return wrap_values(module.named_children())
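
For reference, a worked expansion of the helper above (illustrative):

    # gen_source(self.source, "layer4.2.weight")
    #   == AttrSource(AttrSource(AttrSource(self.source, "layer4"), "2"), "weight")
    # gen_source(self.source, "")   # the root entry yielded by named_modules()
    #   == self.source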
@@ -344,7 +349,7 @@
             return wrap_values(module.named_parameters(**get_kwargs("recurse")))
         elif name == "values":
             assert not (args or kwargs)
-            return wrap_values(module.items(), GetItemSource)
+            return wrap_values(module.items())
         elif name == "items":
             assert not (args or kwargs)
             result = []