[aotd] Alias of intermediate unwrap TensorAlias (#147638)

The bug was reported by an internal user.

AOTD classifies outputs that are aliases of graph intermediates into different categories:

...
- the output is an alias of an intermediate whose base is already a user output
- the output is an alias of an intermediate whose base is not among the user outputs

If we look at this fn:
```
def fn(x):
    ix = x + 1
    a = ix.transpose(0, 1)
    return a.detach(), a
```

output 0: a detach view of the alias a, where a is already a user output
output 1: an alias of the intermediate ix, so an additional graph output for ix is added internally

The base of output 0 is TensorAlias(a) in this case, but it could also be a plain Tensor.
Adding runtime unwrapping of the base solves this problem.
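
To make "runtime unwrapping" concrete, here is a minimal sketch with a stand-in TensorAlias dataclass and a hypothetical unwrap helper (illustrative only, not the exact AOTD types; the real handling lives in AliasOfIntermediateHandler in the diff below):

```
from dataclasses import dataclass

import torch


# Stand-in for AOTD's TensorAlias wrapper, for illustration only.
@dataclass
class TensorAlias:
    alias: torch.Tensor


def unwrap_tensoralias(maybe_alias):
    # The aliased output has to be regenerated from a plain Tensor base, so if
    # the base happens to be wrapped (because it is itself a user output),
    # unwrap it at runtime; plain Tensor bases pass through unchanged.
    return maybe_alias.alias if isinstance(maybe_alias, TensorAlias) else maybe_alias


base = (torch.ones(3, 3) + 1).transpose(0, 1)
assert unwrap_tensoralias(TensorAlias(base)) is base
assert unwrap_tensoralias(base) is base
```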

Alternatively, we could track the base of a.detach() all the way back to ix; in that case the base would always be a plain Tensor, never a TensorAlias.
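
For intuition on that alternative: in eager mode, chained views already record the root (non-view) tensor as their base, which is what tracking the base of a.detach() back to ix would mirror in AOTD's output metadata. A small eager-mode sketch with plain views (not AOTD internals):

```
import torch

x = torch.ones(3, 3, requires_grad=True)
ix = x + 1              # intermediate (not a view)
a = ix.transpose(0, 1)  # view of ix
b = a.unsqueeze(0)      # view of a view

# Chained views collapse to the root tensor: the base of b is ix, not a.
assert a._base is ix
assert b._base is ix
```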

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147638
Approved by: https://github.com/bdhirsh
IvanKobzarev 2025-02-25 09:08:41 -08:00 committed by PyTorch MergeBot
parent 30db64bf51
commit 8594856651
3 changed files with 48 additions and 4 deletions


```
@@ -1087,6 +1087,46 @@ def forward(self, arg0_1, arg1_1):
         self.verify_aot_autograd(f, create_inp(True), test_mutation=True)
         self.verify_aot_autograd(f, create_inp(False), test_mutation=True)

+    @parametrize("backend", ["aot_eager", "inductor"])
+    @parametrize("view_replay_for_aliased_outputs", [False, True])
+    @parametrize("dynamic_shapes", [False, True])
+    def test_alias_of_intermediate_detach(
+        self, backend, view_replay_for_aliased_outputs, dynamic_shapes
+    ):
+        with patch(
+            "torch._functorch.config.view_replay_for_aliased_outputs",
+            view_replay_for_aliased_outputs,
+        ):
+
+            def fn(x):
+                x = x + 1
+                a = x.transpose(0, 1)
+                return a.detach(), a
+
+            def inp_fn():
+                t = torch.ones(3, 3, requires_grad=True)
+                if dynamic_shapes:
+                    torch._dynamo.mark_dynamic(t, 0)
+                    torch._dynamo.mark_dynamic(t, 1)
+                return t
+
+            x_ref = inp_fn()
+            y_ref = fn(x_ref)
+
+            x = inp_fn()
+            y = torch.compile(fn, backend=backend, fullgraph=True)(x)
+            self.assertEqual(y_ref, y)
+
+            y0, y1 = y
+            self.assertFalse(y0.requires_grad)
+            self.assertTrue(y1.requires_grad)
+
+            # Check that detach and diff view point to the same intermediate tensor storage
+            self.assertEqual(y0.data_ptr(), y1.data_ptr())
+            self.assertTrue(y1._is_view())
+
+            sum(y_ref).sum().backward()
+            sum(y).sum().backward()
+            self.assertEqual(x_ref.grad, x.grad)
+
     def test_input_mutation_storage_resize_up(self):
         def f(a):
             torch.ops.inductor.resize_storage_bytes_(a, 32)
```


```
@@ -12858,17 +12858,18 @@ if HAS_GPU and not TEST_WITH_ASAN:
             self.assertTrue(max_live_tensors == 3)

         # See https://github.com/pytorch/pytorch/issues/100348
-        def test_inductor_detach_view(self):
+        @parametrize("backend", ["aot_eager", "inductor"])
+        def test_inductor_detach_view(self, backend):
             def fn(x: torch.Tensor) -> torch.Tensor:
                 a = x * 2
                 return a, a.detach()

-            fn_opt = torch.compile(fn, backend="inductor")
+            fn_opt = torch.compile(fn, backend=backend)
             inp = torch.ones(2, 2, requires_grad=True, device=GPU_TYPE)
             inp_ref = inp.detach().clone().requires_grad_(True)
             out_ref = fn(inp_ref)
-            out = fn_opt(inp)
             out_ref[0].sum().backward()
+            out = fn_opt(inp)
             out[0].sum().backward()
             self.assertEqual(inp.grad, inp_ref.grad)
```


```
@@ -201,6 +201,7 @@ class IsInputHandler:
 class AliasOfIntermediateHandler:
     def __init__(self, info, runtime_metadata, trace_joint):
+        self._unwrap_aliased_base_tensor = _identity
         if info.output_type in (
             OutputType.alias_of_intermediate,
             OutputType.alias_of_intermediate_save_as_output,
@@ -209,6 +210,8 @@ class AliasOfIntermediateHandler:
             self.base_idx = info.base_idx + num_user_outputs
         else:
             self.base_idx = info.base_idx
+        if self.base_idx in runtime_metadata.aliased_out_indices:
+            self._unwrap_aliased_base_tensor = _unwrap_tensoralias

         self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
         self.requires_grad = info.requires_grad
@@ -218,7 +221,7 @@ class AliasOfIntermediateHandler:
     def __call__(self, orig_inputs, fw_outs, out):
         aliased_base_tensor = fw_outs[self.base_idx]
         return gen_alias_from_base(
-            aliased_base_tensor,
+            self._unwrap_aliased_base_tensor(aliased_base_tensor),
             self.unwrap_out(out),
             self.requires_grad,
             self.functional_tensor,
```