[subclass] Fix unwrap subclass parametrization for Nested subclasses (#142481)

@tugsbayasgalan found a bug for nested subclasses:

E.g. we have

TwoTensor(TwoTensor(t1, t2), t0).
After right_inverse we have:

rebuilt_stack == [(TwoTensor, meta, ["a", "b"]), (TwoTensor, meta, ["a", "b"])]
plain_tensors == [t0, t1, t2]
We will first put plain tensors, and only then the nested TwoTensor.

But when we unflatten:
todo = [t0, t1, t2]
we first create TwoTensor(t1, t2)
and put it into todo: [t0, TwoTensor(t1, t2)]
And as a result get

 TwoTensor(t0, TwoTensor(t1, t2))
which is swapping original a and b :)

So the fix should be different: we need to preserve the order of elements in the stack for plain tensors and subclasses.
I will think about the fix.

Fix:

Keep the order of inner_tensor_attr_names consistent with the order in which they were added to the stack (first the plain-tensor attributes, then the subclass attributes).

Test:
```
python test/functorch/test_aotdispatch.py -k test_subclass_parameters
```

Differential Revision: [D67032477](https://our.internmc.facebook.com/intern/diff/D67032477)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142481
Approved by: https://github.com/tugsbayasgalan, https://github.com/bdhirsh
This commit is contained in:
IvanKobzarev 2024-12-10 08:33:10 -08:00 committed by PyTorch MergeBot
parent 7e92b02e09
commit 1d3b0108a6
2 changed files with 24 additions and 10 deletions

View File

@ -6294,7 +6294,10 @@ metadata incorrectly.
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.p = torch.nn.Parameter( self.p = torch.nn.Parameter(
TwoTensor(torch.zeros(3, 4), torch.zeros(3, 4)) TwoTensor(
TwoTensor(torch.zeros(3, 4), torch.randn(3, 4)),
torch.ones(3, 4),
)
) )
def forward(self, x): def forward(self, x):
@ -6305,7 +6308,10 @@ metadata incorrectly.
super().__init__() super().__init__()
self.p1 = torch.nn.Parameter(torch.ones(3, 4)) self.p1 = torch.nn.Parameter(torch.ones(3, 4))
self.p2 = torch.nn.Parameter( self.p2 = torch.nn.Parameter(
TwoTensor(torch.zeros(3, 4), torch.zeros(3, 4)) TwoTensor(
torch.ones(3, 4),
TwoTensor(torch.randn(3, 4), torch.randn(3, 4)),
)
) )
self._m = _M() self._m = _M()

View File

@ -7,10 +7,12 @@ from torch.utils._python_dispatch import is_traceable_wrapper_subclass
class UnwrapTensorSubclass(torch.nn.Module): class UnwrapTensorSubclass(torch.nn.Module):
def forward(self, *tensors) -> torch.Tensor: # type: ignore[no-untyped-def] def forward(self, *tensors) -> torch.Tensor: # type: ignore[no-untyped-def]
todo: List[torch.Tensor] = list(tensors) todo: List[torch.Tensor] = list(tensors)
for tp, meta, inner_tensors in reversed(self.rebuild_stack): for tp, meta, inner_tensors_attrs in reversed(self.rebuild_stack):
nb_tensor: int = len(inner_tensors) num_children: int = len(inner_tensors_attrs)
d = {a: b for a, b in zip(inner_tensors, todo[-nb_tensor:])} # noqa: C416 d = { # noqa: C416
todo = todo[nb_tensor:] a: b for a, b in zip(inner_tensors_attrs, todo[-num_children:])
}
todo = todo[:-num_children]
rebuilt = tp.__tensor_unflatten__(d, meta, None, None) # type: ignore[attr-defined] rebuilt = tp.__tensor_unflatten__(d, meta, None, None) # type: ignore[attr-defined]
todo.append(rebuilt) todo.append(rebuilt)
@ -24,18 +26,24 @@ class UnwrapTensorSubclass(torch.nn.Module):
todo = [tensor] todo = [tensor]
while todo: while todo:
obj = todo.pop() obj = todo.pop()
inner_tensors, metadata = obj.__tensor_flatten__() # type: ignore[attr-defined] inner_tensors_attrnames, metadata = obj.__tensor_flatten__() # type: ignore[attr-defined]
rebuild_stack.append((type(obj), metadata, inner_tensors)) inner_tensors_attrnames_stack_order = []
for attr_name in inner_tensors: subclasses_attrnames = []
for attr_name in inner_tensors_attrnames:
val = getattr(obj, attr_name) val = getattr(obj, attr_name)
if type(val) is torch.Tensor: if type(val) is torch.Tensor:
plain_tensors.append(val) plain_tensors.append(val)
inner_tensors_attrnames_stack_order.append(attr_name)
else: else:
assert isinstance(val, torch.Tensor) assert isinstance(val, torch.Tensor)
todo.append(val) todo.append(val)
subclasses_attrnames.append(attr_name)
inner_tensors_attrnames_stack_order.extend(subclasses_attrnames)
rebuild_stack.append(
(type(obj), metadata, inner_tensors_attrnames_stack_order)
)
self.rebuild_stack = rebuild_stack self.rebuild_stack = rebuild_stack
return plain_tensors return plain_tensors