mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Dynamo] Fix source/reconstruction bugs in NNModule named_* calls (#89729)
Fixes https://github.com/pytorch/torchdynamo/issues/1931 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89729 Approved by: https://github.com/ezyang
This commit is contained in:
parent
447283752c
commit
d88b555577
|
|
@ -595,6 +595,57 @@ class ModuleAttributePrecedence(ModuleAttributePrecedenceBase):
|
|||
return self.activation(self.linear(self.initializer + x)) * self.scale
|
||||
|
||||
|
||||
class ModuleForwardHasGraphBreak(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.layer1 = BasicModule()
|
||||
self.layer2 = BasicModule()
|
||||
self.layer3 = torch.nn.Sequential(BasicModule(), BasicModule())
|
||||
self.layer4 = torch.nn.ModuleList(
|
||||
[
|
||||
torch.nn.Linear(10, 10),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(10, 10),
|
||||
torch.nn.ReLU(),
|
||||
]
|
||||
)
|
||||
self.layer5 = torch.nn.ModuleDict(
|
||||
{
|
||||
"0": torch.nn.Linear(10, 10),
|
||||
}
|
||||
)
|
||||
self.scale = torch.randn(1, 10)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
This is used to test if the results of functions like `named_parameters`
|
||||
can be reconstructed correctly after graph break.
|
||||
|
||||
https://github.com/pytorch/torchdynamo/issues/1931
|
||||
"""
|
||||
x = self.layer1(x)
|
||||
params1 = dict(self.named_parameters())
|
||||
params2 = list(self.parameters())
|
||||
buffers1 = dict(self.named_buffers())
|
||||
buffers2 = list(self.buffers())
|
||||
modules1 = dict(self.named_modules())
|
||||
modules2 = list(self.modules())
|
||||
torch._dynamo.graph_break()
|
||||
y = modules2
|
||||
y = modules1
|
||||
y = buffers2
|
||||
y = buffers1
|
||||
y = params2
|
||||
y = params1
|
||||
x = (
|
||||
self.layer2(x)
|
||||
+ y["layer3.1.linear1.weight"]
|
||||
+ y["layer4.2.weight"]
|
||||
+ y["layer5.0.weight"]
|
||||
)
|
||||
return x * self.scale
|
||||
|
||||
|
||||
def make_test(fn, expected_ops=None):
|
||||
def test_fn(self):
|
||||
return torch._dynamo.testing.standard_test(
|
||||
|
|
@ -646,6 +697,14 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
|
|||
test_module_name_string = make_test(ModuleNameString())
|
||||
test_module_attribute_precedence = make_test(ModuleAttributePrecedence())
|
||||
|
||||
def test_module_forward_has_graph_break(self):
|
||||
m = ModuleForwardHasGraphBreak()
|
||||
x = torch.rand([10, 10])
|
||||
ref = m(x)
|
||||
opt_m = torch._dynamo.optimize("eager")(m)
|
||||
res = opt_m(x)
|
||||
self.assertTrue(torch.allclose(ref, res))
|
||||
|
||||
def test_unsupportedmethod(self):
|
||||
m = UnsupportedMethodCall()
|
||||
i = torch.randn(10)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import functools
|
||||
import inspect
|
||||
import itertools
|
||||
import re
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List
|
||||
|
|
@ -283,18 +282,15 @@ class NNModuleVariable(VariableTracker):
|
|||
bound_args = bound_args.arguments
|
||||
return {k: bound_args[k] for k in names}
|
||||
|
||||
def wrap_values(items, getsource=AttrSource):
|
||||
def wrap_values(items):
|
||||
result = []
|
||||
for name, submod in items:
|
||||
# layer.0.foo => layer[0].foo
|
||||
name = re.sub(r"[.]([0-9]+)([.]|$)", r"[\1]\2", name)
|
||||
src = NNModuleSource(getsource(self.source, name))
|
||||
result.append(
|
||||
tx.output.register_attr_or_module(
|
||||
submod,
|
||||
key,
|
||||
name,
|
||||
source=src,
|
||||
source=NNModuleSource(gen_source(self.source, name)),
|
||||
**options,
|
||||
)
|
||||
)
|
||||
|
|
@ -308,12 +304,21 @@ class NNModuleVariable(VariableTracker):
|
|||
obj,
|
||||
key,
|
||||
name,
|
||||
source=NNModuleSource(AttrSource(self.source, name)),
|
||||
source=NNModuleSource(gen_source(self.source, name)),
|
||||
**options,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def gen_source(source, name):
|
||||
name_split = name.split(".")
|
||||
if name_split[0] == "":
|
||||
return source
|
||||
while len(name_split) > 0:
|
||||
x = name_split.pop(0)
|
||||
source = AttrSource(source, x)
|
||||
return source
|
||||
|
||||
if name == "children":
|
||||
assert not (args or kwargs)
|
||||
return wrap_values(module.named_children())
|
||||
|
|
@ -344,7 +349,7 @@ class NNModuleVariable(VariableTracker):
|
|||
return wrap_values(module.named_parameters(**get_kwargs("recurse")))
|
||||
elif name == "values":
|
||||
assert not (args or kwargs)
|
||||
return wrap_values(module.items(), GetItemSource)
|
||||
return wrap_values(module.items())
|
||||
elif name == "items":
|
||||
assert not (args or kwargs)
|
||||
result = []
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user