Include nn.ParameterDict in dynamo __getitem__ (#99771)

Summary:

Fix: #99735

Test Plan: Please see GitHub tests.

Differential Revision: D45200616

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99771
Approved by: https://github.com/Skylion007, https://github.com/anijain2305
This commit is contained in:
Danni Li 2023-07-11 08:19:01 +00:00 committed by PyTorch MergeBot
parent ba167e6578
commit db4aed6a03
2 changed files with 34 additions and 0 deletions

View File

@ -341,6 +341,37 @@ class ModuleDict(torch.nn.Module):
return x
class ParameterDict(torch.nn.Module):
def __init__(self):
super().__init__()
self.layers = torch.nn.ParameterDict(
{
"0": torch.nn.Parameter(torch.randn(10, 10)),
}
)
def forward(self, x):
x = self.layers["0"].mm(x)
return x
class CustomGetItemParameterDict(torch.nn.Module):
def __init__(self):
super().__init__()
self.layers = torch.nn.ParameterDict(
{
"0": torch.nn.Parameter(torch.randn(10, 10)),
}
)
def __getitem__(self, key: str) -> torch.nn.Module:
return self.layers[key]
def forward(self, x):
x = self["0"].mm(x)
return x
class CustomGetItemModuleDict(torch.nn.Module):
def __init__(self):
super().__init__()
@ -999,6 +1030,8 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
test_modulelist = make_test(CustomGetItemModuleList())
test_moduledict = make_test(ModuleDict())
test_moduledict = make_test(CustomGetItemModuleDict())
test_parameterdict = make_test(ParameterDict())
test_parameterdict = make_test(CustomGetItemParameterDict())
test_super1 = make_test(SuperModule())
test_super2 = make_test(SuperModule2())
test_super_class_method = make_test(SuperChildCallsClassMethod())

View File

@ -536,6 +536,7 @@ class NNModuleVariable(VariableTracker):
builtin_supported = (
torch.nn.ModuleDict.__getitem__,
torch.nn.ModuleList.__getitem__,
torch.nn.ParameterDict.__getitem__,
torch.nn.ParameterList.__getitem__,
torch.nn.Sequential.__getitem__,
)