[subclass] Fix unwrap subclass parametrization for Nested subclasses (#142481)
@tugsbayasgalan found a bug for nested subclasses. E.g. we have TwoTensor(TwoTensor(t1, t2), t0). After right_inverse we have:

rebuild_stack == [(TwoTensor, meta, ["a", "b"]), (TwoTensor, meta, ["a", "b"])]
plain_tensors == [t0, t1, t2]

We put the plain tensors first and only then the nested TwoTensor. But when we unflatten with todo == [t0, t1, t2], we first create TwoTensor(t1, t2) and put it into todo, giving [t0, TwoTensor(t1, t2)], and as a result we get TwoTensor(t0, TwoTensor(t1, t2)), which swaps the original a and b. So the fix needs to preserve the order of elements in the stack for plain tensors and subclasses.

Fix: keep the order of inner tensor attribute names consistent with the order in which their values are pushed onto the stack (first the plain tensor attributes, then the subclass attributes).

Test:
```
python test/functorch/test_aotdispatch.py -k test_subclass_parameters
```

Differential Revision: [D67032477](https://our.internmc.facebook.com/intern/diff/D67032477)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142481
Approved by: https://github.com/tugsbayasgalan, https://github.com/bdhirsh
parent 7e92b02e09
commit 1d3b0108a6
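To make the swap concrete, here is a minimal plain-Python sketch of the rebuild loop described in the commit message. Strings stand in for tensors and a tuple stands in for `__tensor_unflatten__`; both runs below use the corrected end-of-list slicing, so only the stored attribute order differs.

```python
def unflatten(plain, rebuild_stack):
    # Mirrors the rebuild loop in forward(): consume children from the end of
    # `todo` and append the rebuilt wrapper.
    todo = list(plain)
    for attrs in reversed(rebuild_stack):
        n = len(attrs)
        d = dict(zip(attrs, todo[-n:]))
        todo = todo[:-n]
        todo.append(("TwoTensor", d))  # stand-in for tp.__tensor_unflatten__
    return todo[0]


plain = ["t0", "t1", "t2"]  # flattening order: outer plain leaf first, then inner leaves

# Before the fix both stack entries record ["a", "b"], so the outer wrapper pairs
# "a" with t0 and "b" with the rebuilt inner TwoTensor, i.e. a and b are swapped:
print(unflatten(plain, [["a", "b"], ["a", "b"]]))
# ('TwoTensor', {'a': 't0', 'b': ('TwoTensor', {'a': 't1', 'b': 't2'})})

# After the fix the outer entry records plain attributes first, then subclass
# attributes (["b", "a"]), so the original TwoTensor(TwoTensor(t1, t2), t0) comes back:
print(unflatten(plain, [["b", "a"], ["a", "b"]]))
# ('TwoTensor', {'b': 't0', 'a': ('TwoTensor', {'a': 't1', 'b': 't2'})})
```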
@@ -6294,7 +6294,10 @@ metadata incorrectly.
     def __init__(self):
         super().__init__()
         self.p = torch.nn.Parameter(
-            TwoTensor(torch.zeros(3, 4), torch.zeros(3, 4))
+            TwoTensor(
+                TwoTensor(torch.zeros(3, 4), torch.randn(3, 4)),
+                torch.ones(3, 4),
+            )
         )

     def forward(self, x):
@@ -6305,7 +6308,10 @@ metadata incorrectly.
         super().__init__()
         self.p1 = torch.nn.Parameter(torch.ones(3, 4))
         self.p2 = torch.nn.Parameter(
-            TwoTensor(torch.zeros(3, 4), torch.zeros(3, 4))
+            TwoTensor(
+                torch.ones(3, 4),
+                TwoTensor(torch.randn(3, 4), torch.randn(3, 4)),
+            )
         )
         self._m = _M()
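As a side note, the nested parameters introduced by the test flatten through the traceable wrapper subclass protocol exactly as the fix assumes: one level at a time, with one subclass child and one plain child. A small sketch, assuming TwoTensor from torch.testing._internal.two_tensor (the helper the test already uses):

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor

p = TwoTensor(
    TwoTensor(torch.zeros(3, 4), torch.randn(3, 4)),
    torch.ones(3, 4),
)
attr_names, meta = p.__tensor_flatten__()
print(attr_names)            # ['a', 'b']
print(type(p.a), type(p.b))  # one nested TwoTensor child, one plain torch.Tensor child
```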
@@ -7,10 +7,12 @@ from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 class UnwrapTensorSubclass(torch.nn.Module):
     def forward(self, *tensors) -> torch.Tensor:  # type: ignore[no-untyped-def]
         todo: List[torch.Tensor] = list(tensors)
-        for tp, meta, inner_tensors in reversed(self.rebuild_stack):
-            nb_tensor: int = len(inner_tensors)
-            d = {a: b for a, b in zip(inner_tensors, todo[-nb_tensor:])}  # noqa: C416
-            todo = todo[nb_tensor:]
+        for tp, meta, inner_tensors_attrs in reversed(self.rebuild_stack):
+            num_children: int = len(inner_tensors_attrs)
+            d = {  # noqa: C416
+                a: b for a, b in zip(inner_tensors_attrs, todo[-num_children:])
+            }
+            todo = todo[:-num_children]
             rebuilt = tp.__tensor_unflatten__(d, meta, None, None)  # type: ignore[attr-defined]
             todo.append(rebuilt)
@@ -24,18 +26,24 @@ class UnwrapTensorSubclass(torch.nn.Module):
         todo = [tensor]
         while todo:
             obj = todo.pop()
-            inner_tensors, metadata = obj.__tensor_flatten__()  # type: ignore[attr-defined]
-            rebuild_stack.append((type(obj), metadata, inner_tensors))
-            for attr_name in inner_tensors:
+            inner_tensors_attrnames, metadata = obj.__tensor_flatten__()  # type: ignore[attr-defined]
+            inner_tensors_attrnames_stack_order = []
+            subclasses_attrnames = []
+            for attr_name in inner_tensors_attrnames:
                 val = getattr(obj, attr_name)
                 if type(val) is torch.Tensor:
                     plain_tensors.append(val)
+                    inner_tensors_attrnames_stack_order.append(attr_name)
                 else:
                     assert isinstance(val, torch.Tensor)
                     todo.append(val)
+                    subclasses_attrnames.append(attr_name)
+            inner_tensors_attrnames_stack_order.extend(subclasses_attrnames)
+            rebuild_stack.append(
+                (type(obj), metadata, inner_tensors_attrnames_stack_order)
+            )

         self.rebuild_stack = rebuild_stack

         return plain_tensors
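For reference, here is a hedged end-to-end sketch of the fixed bookkeeping. The `right_inverse` and `rebuild` functions below are illustrative stand-ins re-implementing the logic from the diff (the module that defines UnwrapTensorSubclass is not named in this page, so nothing is imported from it); they exercise the real __tensor_flatten__/__tensor_unflatten__ protocol with TwoTensor.

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


def right_inverse(tensor):
    """Flatten a (possibly nested) wrapper subclass into plain tensors plus a rebuild stack."""
    plain_tensors, rebuild_stack, todo = [], [], [tensor]
    while todo:
        obj = todo.pop()
        attr_names, metadata = obj.__tensor_flatten__()
        plain_attrs, subclass_attrs = [], []
        for name in attr_names:
            val = getattr(obj, name)
            if is_traceable_wrapper_subclass(val):
                todo.append(val)            # nested subclass: revisit later
                subclass_attrs.append(name)
            else:
                plain_tensors.append(val)   # plain leaf: emitted immediately
                plain_attrs.append(name)
        # The fix: record plain attributes first, then subclass attributes.
        rebuild_stack.append((type(obj), metadata, plain_attrs + subclass_attrs))
    return plain_tensors, rebuild_stack


def rebuild(plain_tensors, rebuild_stack):
    """Mirror of forward(): rebuild the nested subclass from the plain tensors."""
    todo = list(plain_tensors)
    for tp, meta, attr_names in reversed(rebuild_stack):
        n = len(attr_names)
        d = dict(zip(attr_names, todo[-n:]))
        todo = todo[:-n]                    # consume from the end, matching the slice above
        todo.append(tp.__tensor_unflatten__(d, meta, None, None))
    return todo[0]


t0, t1, t2 = torch.ones(2), torch.zeros(2), torch.randn(2)
nested = TwoTensor(TwoTensor(t1, t2), t0)
plain, stack = right_inverse(nested)
out = rebuild(plain, stack)
# a keeps its nested TwoTensor, b stays a plain tensor: no swap.
assert isinstance(out.a, TwoTensor) and not isinstance(out.b, TwoTensor)
torch.testing.assert_close(out.b, t0)
```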