[subclass] Fix unwrap_tensor_subclass_parameters parametrization (#142155)

A parametrization cannot be registered on a module for parameters that are not its direct children (i.e. parameters owned by nested submodules).
We therefore have to iterate through all submodules and register the parametrization at every level.

The original test case did not cover nested modules, so this change adds a submodule to the test.

Testing:
```
python test/functorch/test_aotdispatch.py -k test_subclass_parameters
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142155
Approved by: https://github.com/bdhirsh
This commit is contained in:
IvanKobzarev 2024-12-05 10:39:54 -08:00 committed by PyTorch MergeBot
parent 2bfc600644
commit efab8c433f
2 changed files with 22 additions and 5 deletions

View File

@ -6291,6 +6291,16 @@ metadata incorrectly.
torch.compile(fn_, backend="inductor", fullgraph=True)(x)
def test_subclass_parameters(self):
class _M(torch.nn.Module):
def __init__(self):
super().__init__()
self.p = torch.nn.Parameter(
TwoTensor(torch.zeros(3, 4), torch.zeros(3, 4))
)
def forward(self, x):
return x + self.p
class M(torch.nn.Module):
def __init__(self):
super().__init__()
@ -6298,9 +6308,10 @@ metadata incorrectly.
self.p2 = torch.nn.Parameter(
TwoTensor(torch.zeros(3, 4), torch.zeros(3, 4))
)
self._m = _M()
def forward(self, x):
return x + 2 * self.p1 + self.p2
return self._m(x) + x + 2 * self.p1 + self.p2
m = M()
ref_x = torch.randn(3, 4)

View File

@ -39,7 +39,7 @@ class UnwrapTensorSubclass(torch.nn.Module):
return plain_tensors
def unwrap_tensor_subclass_parameters(model: torch.nn.Module) -> torch.nn.Module:
def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Module:
"""
Model transformation that replaces all the parameters that are subclasses to plain tensors.
This reduces runtime overhead of flattening/unflattening the parameters.
@ -51,10 +51,16 @@ def unwrap_tensor_subclass_parameters(model: torch.nn.Module) -> torch.nn.Module
becomes: {"parametrizations.p2.original0": torch.Tensor, "parametrizations.p2.original1": torch.Tensor}
"""
name_param: List[Tuple[str, torch.nn.Parameter]] = list(model.named_parameters())
name_param: List[Tuple[str, torch.nn.Parameter]] = list(
module.named_parameters(recurse=False)
)
for name, param in name_param:
if is_traceable_wrapper_subclass(param):
torch.nn.utils.parametrize.register_parametrization(
model, name, UnwrapTensorSubclass()
module, name, UnwrapTensorSubclass()
)
return model
for name, child in module.named_children():
unwrap_tensor_subclass_parameters(child)
return module