mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Include nn.ParameterDict in dynamo __getitem__ (#99771)
Summary: Fix: #99735 Test Plan: Please see GitHub tests. Differential Revision: D45200616 Pull Request resolved: https://github.com/pytorch/pytorch/pull/99771 Approved by: https://github.com/Skylion007, https://github.com/anijain2305
This commit is contained in:
parent
ba167e6578
commit
db4aed6a03
|
|
@ -341,6 +341,37 @@ class ModuleDict(torch.nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class ParameterDict(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ParameterDict(
|
||||
{
|
||||
"0": torch.nn.Parameter(torch.randn(10, 10)),
|
||||
}
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.layers["0"].mm(x)
|
||||
return x
|
||||
|
||||
|
||||
class CustomGetItemParameterDict(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ParameterDict(
|
||||
{
|
||||
"0": torch.nn.Parameter(torch.randn(10, 10)),
|
||||
}
|
||||
)
|
||||
|
||||
def __getitem__(self, key: str) -> torch.nn.Module:
|
||||
return self.layers[key]
|
||||
|
||||
def forward(self, x):
|
||||
x = self["0"].mm(x)
|
||||
return x
|
||||
|
||||
|
||||
class CustomGetItemModuleDict(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
@ -999,6 +1030,8 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
|
|||
test_modulelist = make_test(CustomGetItemModuleList())
|
||||
test_moduledict = make_test(ModuleDict())
|
||||
test_moduledict = make_test(CustomGetItemModuleDict())
|
||||
test_parameterdict = make_test(ParameterDict())
|
||||
test_parameterdict = make_test(CustomGetItemParameterDict())
|
||||
test_super1 = make_test(SuperModule())
|
||||
test_super2 = make_test(SuperModule2())
|
||||
test_super_class_method = make_test(SuperChildCallsClassMethod())
|
||||
|
|
|
|||
|
|
@ -536,6 +536,7 @@ class NNModuleVariable(VariableTracker):
|
|||
builtin_supported = (
|
||||
torch.nn.ModuleDict.__getitem__,
|
||||
torch.nn.ModuleList.__getitem__,
|
||||
torch.nn.ParameterDict.__getitem__,
|
||||
torch.nn.ParameterList.__getitem__,
|
||||
torch.nn.Sequential.__getitem__,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user