mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[subclass] Fix unwrap_subclass_parameters parametrization (#142155)
[subclass] Fix unwrap_subclass_parameters parametrization (#142155)
Parametrization cannot be registered for non-direct child parameters of a module. We have to iterate through all submodules and register the parametrization at every level. The original test case did not cover the nested-modules case — a submodule is added to the test. Testing: ``` python test/functorch/test_aotdispatch.py -k test_subclass_parameters ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/142155 Approved by: https://github.com/bdhirsh
This commit is contained in:
parent
2bfc600644
commit
efab8c433f
|
|
@ -6291,6 +6291,16 @@ metadata incorrectly.
|
|||
torch.compile(fn_, backend="inductor", fullgraph=True)(x)
|
||||
|
||||
def test_subclass_parameters(self):
|
||||
class _M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.p = torch.nn.Parameter(
|
||||
TwoTensor(torch.zeros(3, 4), torch.zeros(3, 4))
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.p
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
@ -6298,9 +6308,10 @@ metadata incorrectly.
|
|||
self.p2 = torch.nn.Parameter(
|
||||
TwoTensor(torch.zeros(3, 4), torch.zeros(3, 4))
|
||||
)
|
||||
self._m = _M()
|
||||
|
||||
def forward(self, x):
|
||||
return x + 2 * self.p1 + self.p2
|
||||
return self._m(x) + x + 2 * self.p1 + self.p2
|
||||
|
||||
m = M()
|
||||
ref_x = torch.randn(3, 4)
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class UnwrapTensorSubclass(torch.nn.Module):
|
|||
return plain_tensors
|
||||
|
||||
|
||||
def unwrap_tensor_subclass_parameters(model: torch.nn.Module) -> torch.nn.Module:
|
||||
def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Module:
|
||||
"""
|
||||
Model transformation that replaces all the parameters that are subclasses to plain tensors.
|
||||
This reduces runtime overhead of flattening/unflattening the parameters.
|
||||
|
|
@ -51,10 +51,16 @@ def unwrap_tensor_subclass_parameters(model: torch.nn.Module) -> torch.nn.Module
|
|||
becomes: {"parametrizations.p2.original0": torch.Tensor, "parametrizations.p2.original1": torch.Tensor}
|
||||
|
||||
"""
|
||||
name_param: List[Tuple[str, torch.nn.Parameter]] = list(model.named_parameters())
|
||||
name_param: List[Tuple[str, torch.nn.Parameter]] = list(
|
||||
module.named_parameters(recurse=False)
|
||||
)
|
||||
for name, param in name_param:
|
||||
if is_traceable_wrapper_subclass(param):
|
||||
torch.nn.utils.parametrize.register_parametrization(
|
||||
model, name, UnwrapTensorSubclass()
|
||||
module, name, UnwrapTensorSubclass()
|
||||
)
|
||||
return model
|
||||
|
||||
for name, child in module.named_children():
|
||||
unwrap_tensor_subclass_parameters(child)
|
||||
|
||||
return module
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user