Fix subclass unwrapping bug (#143945)

I noticed a small bug in tensor subclass unwrapping logic. cc @IvanKobzarev
It seems easier if we just implement it recursively so that it is easier to track the inner attrs to corresponding plain tensors and both aot_autograd and fake_tensor implement subclass unwrapping recursively.

Differential Revision: [D67693610](https://our.internmc.facebook.com/intern/diff/D67693610)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143945
Approved by: https://github.com/IvanKobzarev
This commit is contained in:
Tugsbayasgalan Manlaibaatar 2024-12-28 14:01:01 -08:00 committed by PyTorch MergeBot
parent 5c783bf410
commit d65a50ef34
3 changed files with 104 additions and 37 deletions

View File

@ -1659,11 +1659,11 @@ graph():
%add : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %mul), kwargs = {})
%add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %p_parametrizations_p2_original0), kwargs = {})
%add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %p_parametrizations_p2_original1), kwargs = {})
%add_3 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %p_parametrizations_p2_original2), kwargs = {})
%add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %p_parametrizations_p2_original3), kwargs = {})
%add_5 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_3, %add_4), kwargs = {})
%add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_2, %add_5), kwargs = {})
%add_7 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, %add_6), kwargs = {})
%add_3 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, %add_2), kwargs = {})
%add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %p_parametrizations_p2_original2), kwargs = {})
%add_5 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, %p_parametrizations_p2_original3), kwargs = {})
%add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_4, %add_5), kwargs = {})
%add_7 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_3, %add_6), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%add_7,), kwargs = {})
%sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%sum_1,), kwargs = {})
%add_8 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %sum_2), kwargs = {})

View File

@ -6328,6 +6328,53 @@ metadata incorrectly.
out.sum().backward()
self.assertEqual(ref_x.grad, x.grad)
def test_subclass_parameters_torture_case(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.p1 = torch.nn.Parameter(torch.ones(3, 4))
self.p2 = torch.nn.Parameter(
TwoTensor(
TwoTensor(
torch.ones(3, 4),
TwoTensor(torch.randn(3, 4), torch.randn(3, 4)),
),
TwoTensor(
TwoTensor(torch.randn(3, 4), torch.randn(3, 4)),
TwoTensor(torch.ones(3, 4), torch.randn(3, 4)),
),
)
)
def forward(self, x):
return x + 2 * self.p1 + self.p2.a.b
m = M()
ref_x = torch.randn(3, 4)
ref_out = m(ref_x)
ref_out.sum().backward()
m.zero_grad()
from torch._functorch._aot_autograd.subclass_parametrization import (
unwrap_tensor_subclass_parameters,
)
unwrap_tensor_subclass_parameters(m)
ref_x2 = ref_x.detach().clone()
ref_out2 = m(ref_x2)
self.assertEqual(ref_out2, ref_out)
ref_out2.sum().backward()
self.assertEqual(ref_x2.grad, ref_x.grad)
m.zero_grad()
x = ref_x.detach().clone()
comp_fn = torch.compile(m, backend="aot_eager", fullgraph=True)
out = comp_fn(x)
self.assertEqual(ref_out, out)
out.sum().backward()
self.assertEqual(ref_x.grad, x.grad)
def test_rrelu_with_noise_mutation(self):
def fn_functional(x):
noise = torch.ones_like(x)

View File

@ -1,49 +1,69 @@
from typing import List, Tuple
import dataclasses
from typing import Any, Dict, List, Tuple
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
# This is technically very similar to SubclassCreatingMeta
# in aot_autograd, but we don't need all the stuff in there
# so just recreated a new dataclass.
@dataclasses.dataclass
class SubclassCreationMeta:
start_idx: int
num_tensors: int
class_type: Any
attrs: Dict[str, "SubclassCreationMeta"]
metadata: Any
class UnwrapTensorSubclass(torch.nn.Module):
def forward(self, *tensors) -> torch.Tensor: # type: ignore[no-untyped-def]
todo: List[torch.Tensor] = list(tensors)
for tp, meta, inner_tensors_attrs in reversed(self.rebuild_stack):
num_children: int = len(inner_tensors_attrs)
d = { # noqa: C416
a: b for a, b in zip(inner_tensors_attrs, todo[-num_children:])
}
todo = todo[:-num_children]
rebuilt = tp.__tensor_unflatten__(d, meta, None, None) # type: ignore[attr-defined]
todo.append(rebuilt)
assert len(todo) == 1
return todo[0]
def _unwrap_tensor_subclasses(subclass_meta, tensors, offset): # type: ignore[no-untyped-def]
if subclass_meta is None:
return tensors[offset], offset + 1
inner_tensors = {}
for attr, meta in subclass_meta.attrs.items():
built_tensor, offset = _unwrap_tensor_subclasses(meta, tensors, offset)
inner_tensors[attr] = built_tensor
rebuilt = subclass_meta.class_type.__tensor_unflatten__(
inner_tensors, subclass_meta.metadata, None, None
)
return rebuilt, offset
return _unwrap_tensor_subclasses(self.subclass_meta, todo, 0)[0]
def right_inverse(self, tensor: torch.Tensor) -> List[torch.Tensor]:
assert type(tensor) is not torch.Tensor
rebuild_stack = []
plain_tensors = []
todo = [tensor]
while todo:
obj = todo.pop()
inner_tensors_attrnames, metadata = obj.__tensor_flatten__() # type: ignore[attr-defined]
inner_tensors_attrnames_stack_order = []
subclasses_attrnames = []
for attr_name in inner_tensors_attrnames:
val = getattr(obj, attr_name)
if type(val) is torch.Tensor:
plain_tensors.append(val)
inner_tensors_attrnames_stack_order.append(attr_name)
else:
assert isinstance(val, torch.Tensor)
todo.append(val)
subclasses_attrnames.append(attr_name)
inner_tensors_attrnames_stack_order.extend(subclasses_attrnames)
rebuild_stack.append(
(type(obj), metadata, inner_tensors_attrnames_stack_order)
plain_tensors: List[torch.Tensor] = []
def _create_subclass_meta(tensor, idx, plain_tensor_container): # type: ignore[no-untyped-def]
if type(tensor) is torch.Tensor:
plain_tensor_container.append(tensor)
return None, idx + 1
inner_tensors_attrnames, metadata = tensor.__tensor_flatten__() # type: ignore[attr-defined]
new_idx = idx
attr_to_meta = {}
for attr in inner_tensors_attrnames:
val = getattr(tensor, attr)
subclass_meta, new_idx = _create_subclass_meta(
val, new_idx, plain_tensor_container
)
attr_to_meta[attr] = subclass_meta
return (
SubclassCreationMeta(
start_idx=idx,
num_tensors=new_idx - idx,
class_type=type(tensor),
attrs=attr_to_meta,
metadata=metadata,
),
new_idx,
)
self.rebuild_stack = rebuild_stack
self.subclass_meta = _create_subclass_meta(tensor, 0, plain_tensors)[0]
return plain_tensors