Fix subclass unwrapping bug (#143945)
I noticed a small bug in the tensor subclass unwrapping logic (cc @IvanKobzarev). It seems easier to implement the unwrapping recursively: that makes it straightforward to track which inner attributes correspond to which plain tensors, and it matches how both aot_autograd and fake_tensor already unwrap subclasses recursively.

Differential Revision: [D67693610](https://our.internmc.facebook.com/intern/diff/D67693610)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143945
Approved by: https://github.com/IvanKobzarev
This commit is contained in:
parent
5c783bf410
commit
d65a50ef34
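The point of the change: a traceable wrapper subclass can itself contain other subclasses, so mapping inner attributes to plain tensors has to recurse rather than walk a flat stack. Below is a minimal sketch of that flattening idea, not the code in this PR; flatten_to_plain is a hypothetical helper, and TwoTensor is the test subclass from torch.testing._internal.two_tensor.

import torch
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

def flatten_to_plain(t, out):
    # Recurse through wrapper subclasses; the leaves are plain torch.Tensors.
    if not is_traceable_wrapper_subclass(t):
        out.append(t)
        return
    attr_names, _ctx = t.__tensor_flatten__()
    for name in attr_names:
        flatten_to_plain(getattr(t, name), out)

# A doubly nested TwoTensor flattens into three plain leaf tensors,
# collected in attribute order by the recursive walk.
nested = TwoTensor(TwoTensor(torch.ones(2), torch.zeros(2)), torch.randn(2))
leaves: list = []
flatten_to_plain(nested, leaves)
assert len(leaves) == 3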
@@ -1659,11 +1659,11 @@ graph():
     %add : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %mul), kwargs = {})
     %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %p_parametrizations_p2_original0), kwargs = {})
     %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %p_parametrizations_p2_original1), kwargs = {})
-    %add_3 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %p_parametrizations_p2_original2), kwargs = {})
-    %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %p_parametrizations_p2_original3), kwargs = {})
-    %add_5 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_3, %add_4), kwargs = {})
-    %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_2, %add_5), kwargs = {})
-    %add_7 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, %add_6), kwargs = {})
+    %add_3 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, %add_2), kwargs = {})
+    %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %p_parametrizations_p2_original2), kwargs = {})
+    %add_5 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %p_parametrizations_p2_original3), kwargs = {})
+    %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_4, %add_5), kwargs = {})
+    %add_7 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_3, %add_6), kwargs = {})
     %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%add_7,), kwargs = {})
     %sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%sum_1,), kwargs = {})
     %add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %sum_2), kwargs = {})
@@ -6328,6 +6328,53 @@ metadata incorrectly.
         out.sum().backward()
         self.assertEqual(ref_x.grad, x.grad)
 
+    def test_subclass_parameters_torture_case(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.p1 = torch.nn.Parameter(torch.ones(3, 4))
+                self.p2 = torch.nn.Parameter(
+                    TwoTensor(
+                        TwoTensor(
+                            torch.ones(3, 4),
+                            TwoTensor(torch.randn(3, 4), torch.randn(3, 4)),
+                        ),
+                        TwoTensor(
+                            TwoTensor(torch.randn(3, 4), torch.randn(3, 4)),
+                            TwoTensor(torch.ones(3, 4), torch.randn(3, 4)),
+                        ),
+                    )
+                )
+
+            def forward(self, x):
+                return x + 2 * self.p1 + self.p2.a.b
+
+        m = M()
+        ref_x = torch.randn(3, 4)
+        ref_out = m(ref_x)
+        ref_out.sum().backward()
+        m.zero_grad()
+
+        from torch._functorch._aot_autograd.subclass_parametrization import (
+            unwrap_tensor_subclass_parameters,
+        )
+
+        unwrap_tensor_subclass_parameters(m)
+
+        ref_x2 = ref_x.detach().clone()
+        ref_out2 = m(ref_x2)
+        self.assertEqual(ref_out2, ref_out)
+        ref_out2.sum().backward()
+        self.assertEqual(ref_x2.grad, ref_x.grad)
+        m.zero_grad()
+
+        x = ref_x.detach().clone()
+        comp_fn = torch.compile(m, backend="aot_eager", fullgraph=True)
+        out = comp_fn(x)
+        self.assertEqual(ref_out, out)
+        out.sum().backward()
+        self.assertEqual(ref_x.grad, x.grad)
+
     def test_rrelu_with_noise_mutation(self):
         def fn_functional(x):
             noise = torch.ones_like(x)
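A note on the new test: unwrap_tensor_subclass_parameters(m) registers an UnwrapTensorSubclass parametrization (rewritten in the next hunk) for each subclass parameter, so the plain leaf tensors become the registered original0, original1, ... tensors while attribute access still rebuilds the nested TwoTensor. A rough way to observe this, assuming the test module m above; the original0/original1 naming follows the expected graph earlier in this diff.

from torch.nn.utils import parametrize

# p2 is now backed by a parametrization over its plain inner tensors.
assert parametrize.is_parametrized(m, "p2")
# The plain leaves are registered as original0, original1, ... on the
# parametrization, while m.p2 still returns the reconstructed TwoTensor.
print([name for name, _ in m.parametrizations["p2"].named_parameters()])
print(type(m.p2))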
@@ -1,49 +1,69 @@
-from typing import List, Tuple
+import dataclasses
+from typing import Any, Dict, List, Tuple
 
 import torch
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 
 
+# This is technically very similar to SubclassCreatingMeta
+# in aot_autograd, but we don't need all the stuff in there
+# so just recreated a new dataclass.
+@dataclasses.dataclass
+class SubclassCreationMeta:
+    start_idx: int
+    num_tensors: int
+    class_type: Any
+    attrs: Dict[str, "SubclassCreationMeta"]
+    metadata: Any
+
+
 class UnwrapTensorSubclass(torch.nn.Module):
     def forward(self, *tensors) -> torch.Tensor:  # type: ignore[no-untyped-def]
         todo: List[torch.Tensor] = list(tensors)
-        for tp, meta, inner_tensors_attrs in reversed(self.rebuild_stack):
-            num_children: int = len(inner_tensors_attrs)
-            d = {  # noqa: C416
-                a: b for a, b in zip(inner_tensors_attrs, todo[-num_children:])
-            }
-            todo = todo[:-num_children]
-            rebuilt = tp.__tensor_unflatten__(d, meta, None, None)  # type: ignore[attr-defined]
-            todo.append(rebuilt)
 
-        assert len(todo) == 1
-        return todo[0]
+        def _unwrap_tensor_subclasses(subclass_meta, tensors, offset):  # type: ignore[no-untyped-def]
+            if subclass_meta is None:
+                return tensors[offset], offset + 1
+            inner_tensors = {}
+            for attr, meta in subclass_meta.attrs.items():
+                built_tensor, offset = _unwrap_tensor_subclasses(meta, tensors, offset)
+                inner_tensors[attr] = built_tensor
+            rebuilt = subclass_meta.class_type.__tensor_unflatten__(
+                inner_tensors, subclass_meta.metadata, None, None
+            )
+            return rebuilt, offset
+
+        return _unwrap_tensor_subclasses(self.subclass_meta, todo, 0)[0]
 
     def right_inverse(self, tensor: torch.Tensor) -> List[torch.Tensor]:
         assert type(tensor) is not torch.Tensor
-        rebuild_stack = []
-        plain_tensors = []
-        todo = [tensor]
-        while todo:
-            obj = todo.pop()
-            inner_tensors_attrnames, metadata = obj.__tensor_flatten__()  # type: ignore[attr-defined]
-            inner_tensors_attrnames_stack_order = []
-            subclasses_attrnames = []
-            for attr_name in inner_tensors_attrnames:
-                val = getattr(obj, attr_name)
-                if type(val) is torch.Tensor:
-                    plain_tensors.append(val)
-                    inner_tensors_attrnames_stack_order.append(attr_name)
-                else:
-                    assert isinstance(val, torch.Tensor)
-                    todo.append(val)
-                    subclasses_attrnames.append(attr_name)
-            inner_tensors_attrnames_stack_order.extend(subclasses_attrnames)
-            rebuild_stack.append(
-                (type(obj), metadata, inner_tensors_attrnames_stack_order)
+        plain_tensors: List[torch.Tensor] = []
+
+        def _create_subclass_meta(tensor, idx, plain_tensor_container):  # type: ignore[no-untyped-def]
+            if type(tensor) is torch.Tensor:
+                plain_tensor_container.append(tensor)
+                return None, idx + 1
+            inner_tensors_attrnames, metadata = tensor.__tensor_flatten__()  # type: ignore[attr-defined]
+            new_idx = idx
+            attr_to_meta = {}
+            for attr in inner_tensors_attrnames:
+                val = getattr(tensor, attr)
+                subclass_meta, new_idx = _create_subclass_meta(
+                    val, new_idx, plain_tensor_container
+                )
+                attr_to_meta[attr] = subclass_meta
+            return (
+                SubclassCreationMeta(
+                    start_idx=idx,
+                    num_tensors=new_idx - idx,
+                    class_type=type(tensor),
+                    attrs=attr_to_meta,
+                    metadata=metadata,
+                ),
+                new_idx,
             )
 
-        self.rebuild_stack = rebuild_stack
+        self.subclass_meta = _create_subclass_meta(tensor, 0, plain_tensors)[0]
         return plain_tensors
 
 
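Round-trip of the rewritten module, as a sketch: right_inverse records a SubclassCreationMeta tree and returns the plain leaf tensors, and forward consumes the same tree to rebuild the nested subclass. Normally torch.nn.utils.parametrize drives these calls; invoking them directly, with TwoTensor from torch.testing._internal.two_tensor, is only for illustration.

import torch
from torch._functorch._aot_autograd.subclass_parametrization import UnwrapTensorSubclass
from torch.testing._internal.two_tensor import TwoTensor

nested = TwoTensor(TwoTensor(torch.ones(2), torch.zeros(2)), torch.randn(2))
unwrap = UnwrapTensorSubclass()
plain = unwrap.right_inverse(nested)  # records the SubclassCreationMeta tree, returns 3 plain tensors
rebuilt = unwrap(*plain)              # forward() rebuilds the nested TwoTensor from the plain tensors
assert type(rebuilt) is TwoTensor and type(rebuilt.a) is TwoTensor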