[subclass] Fix unwrap subclass parametrization for Nested subclasses (#142481)
@tugsbayasgalan found a bug for nested subclasses. E.g. for TwoTensor(TwoTensor(t1, t2), t0), after right_inverse we have:

rebuild_stack == [(TwoTensor, meta, ["a", "b"]), (TwoTensor, meta, ["a", "b"])]
plain_tensors == [t0, t1, t2]

The plain tensor t0 is appended first, and the inner tensors of the nested TwoTensor only after it. But when we unflatten, starting from todo = [t0, t1, t2], we first create TwoTensor(t1, t2) and push it, so todo becomes [t0, TwoTensor(t1, t2)], and the result is TwoTensor(t0, TwoTensor(t1, t2)), which swaps the original a and b. So the fix must preserve, for each stack entry, the order in which its plain and subclass elements were added to the stack.

Fix: keep inner_tensor_attr_names in the order their values are added to the stack (first the plain tensor attributes, then the subclass attributes).

Test:
```
python test/functorch/test_aotdispatch.py -k test_subclass_parameters
```

Differential Revision: [D67032477](https://our.internmc.facebook.com/intern/diff/D67032477)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142481
Approved by: https://github.com/tugsbayasgalan, https://github.com/bdhirsh
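To make the swap concrete, here is a minimal standalone sketch (not part of the commit) that replays the pre-fix ordering described above. Plain dicts stand in for TwoTensor and strings stand in for tensors; every name in it is invented for illustration.

```python
# Pre-fix state for TwoTensor(TwoTensor(t1, t2), t0), per the description:
# both stack entries keep the attribute order ["a", "b"], and the plain
# tensors arrive as [t0, t1, t2].  Dicts stand in for TwoTensor here.
rebuild_stack = [(dict, ["a", "b"]), (dict, ["a", "b"])]
todo = ["t0", "t1", "t2"]

# Unflatten: rebuild innermost-first, pairing attr names with the values
# at the end of the stack.
for tp, attrs in reversed(rebuild_stack):
    n = len(attrs)
    rebuilt = tp(zip(attrs, todo[-n:]))
    todo = todo[:-n]
    todo.append(rebuilt)

# "a" ends up holding t0 and "b" the inner pair -- swapped relative to the
# original TwoTensor(TwoTensor(t1, t2), t0).
assert todo[0] == {"a": "t0", "b": {"a": "t1", "b": "t2"}}
```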
This commit is contained in:
parent
7e92b02e09
commit
1d3b0108a6
@@ -6294,7 +6294,10 @@ metadata incorrectly.
             def __init__(self):
                 super().__init__()
                 self.p = torch.nn.Parameter(
-                    TwoTensor(torch.zeros(3, 4), torch.zeros(3, 4))
+                    TwoTensor(
+                        TwoTensor(torch.zeros(3, 4), torch.randn(3, 4)),
+                        torch.ones(3, 4),
+                    )
                 )

             def forward(self, x):
@@ -6305,7 +6308,10 @@ metadata incorrectly.
                 super().__init__()
                 self.p1 = torch.nn.Parameter(torch.ones(3, 4))
                 self.p2 = torch.nn.Parameter(
-                    TwoTensor(torch.zeros(3, 4), torch.zeros(3, 4))
+                    TwoTensor(
+                        torch.ones(3, 4),
+                        TwoTensor(torch.randn(3, 4), torch.randn(3, 4)),
+                    )
                 )
                 self._m = _M()
@@ -7,10 +7,12 @@ from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 class UnwrapTensorSubclass(torch.nn.Module):
     def forward(self, *tensors) -> torch.Tensor:  # type: ignore[no-untyped-def]
         todo: List[torch.Tensor] = list(tensors)
-        for tp, meta, inner_tensors in reversed(self.rebuild_stack):
-            nb_tensor: int = len(inner_tensors)
-            d = {a: b for a, b in zip(inner_tensors, todo[-nb_tensor:])}  # noqa: C416
-            todo = todo[nb_tensor:]
+        for tp, meta, inner_tensors_attrs in reversed(self.rebuild_stack):
+            num_children: int = len(inner_tensors_attrs)
+            d = {  # noqa: C416
+                a: b for a, b in zip(inner_tensors_attrs, todo[-num_children:])
+            }
+            todo = todo[:-num_children]
             rebuilt = tp.__tensor_unflatten__(d, meta, None, None)  # type: ignore[attr-defined]
             todo.append(rebuilt)
@@ -24,18 +26,24 @@ class UnwrapTensorSubclass(torch.nn.Module):
         todo = [tensor]
         while todo:
             obj = todo.pop()
-            inner_tensors, metadata = obj.__tensor_flatten__()  # type: ignore[attr-defined]
-            rebuild_stack.append((type(obj), metadata, inner_tensors))
-            for attr_name in inner_tensors:
+            inner_tensors_attrnames, metadata = obj.__tensor_flatten__()  # type: ignore[attr-defined]
+            inner_tensors_attrnames_stack_order = []
+            subclasses_attrnames = []
+            for attr_name in inner_tensors_attrnames:
                 val = getattr(obj, attr_name)
                 if type(val) is torch.Tensor:
                     plain_tensors.append(val)
+                    inner_tensors_attrnames_stack_order.append(attr_name)
                 else:
                     assert isinstance(val, torch.Tensor)
                     todo.append(val)
+                    subclasses_attrnames.append(attr_name)
+            inner_tensors_attrnames_stack_order.extend(subclasses_attrnames)
+            rebuild_stack.append(
+                (type(obj), metadata, inner_tensors_attrnames_stack_order)
+            )

         self.rebuild_stack = rebuild_stack

         return plain_tensors
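For comparison with the fixed ordering, here is a torch-free sketch (again with dicts in place of subclasses and strings in place of tensors; all names are invented for the example) that mirrors the fixed right_inverse/forward flow. With the outer entry recorded as ["b", "a"], the round trip preserves the attributes.

```python
def flatten(root):
    """Rough analogue of the fixed right_inverse: plain children first,
    then subclass children, attr names recorded in that same order."""
    rebuild_stack, plain, todo = [], [], [root]
    while todo:
        obj = todo.pop()
        plain_attrs, subclass_attrs = [], []
        for name, val in obj.items():
            if isinstance(val, dict):      # "subclass" child: defer it
                todo.append(val)
                subclass_attrs.append(name)
            else:                          # "plain tensor" child
                plain.append(val)
                plain_attrs.append(name)
        rebuild_stack.append((dict, plain_attrs + subclass_attrs))
    return rebuild_stack, plain


def unflatten(rebuild_stack, plain):
    """Rough analogue of forward: rebuild innermost-first from the stack end."""
    todo = list(plain)
    for tp, attrs in reversed(rebuild_stack):
        n = len(attrs)
        rebuilt = tp(zip(attrs, todo[-n:]))
        todo = todo[:-n]
        todo.append(rebuilt)
    return todo[0]


nested = {"a": {"a": "t1", "b": "t2"}, "b": "t0"}  # TwoTensor(TwoTensor(t1, t2), t0)
stack, plain = flatten(nested)
assert plain == ["t0", "t1", "t2"]
assert stack[0][1] == ["b", "a"]                   # outer entry: plain attr first
assert unflatten(stack, plain) == nested           # "a" and "b" stay put
```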