[aotd] Alias of intermediate unwrap TensorAlias (#147638)
The bug was reported by an internal user.
AOTD classifies outputs that are aliases of graph intermediates into different categories:
...
- the output is an alias of an intermediate whose base is already an output
- the output is an alias of an intermediate whose base is not among the outputs
If we look at the function:
```
def fn(x):
ix = x + 1
a = ix.transpose(0, 1)
return a.detach(), a
```
output 0: a detached view of the alias `a`, where `a` is itself already an output
output 1: an alias of the intermediate `ix`; `ix` is then added internally as an additional output
In this case the base of output 0 is `TensorAlias(a)`, but it could also be a plain `Tensor`.
Adding runtime unwrapping of the base solves this problem.
Alternatively, we could track the base of `a.detach()` all the way back to `ix`; in that case the base would always be a `Tensor`, never a `TensorAlias`.
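To make the runtime unwrapping concrete, here is a minimal sketch of the idea; the `TensorAlias` dataclass and `unwrap` helper below are simplified stand-ins for illustration only, not AOTAutograd's actual internals.
```python
import torch
from dataclasses import dataclass


@dataclass
class TensorAlias:
    # Simplified stand-in for the wrapper that marks an output which aliases
    # another output or intermediate of the compiled graph.
    alias: torch.Tensor


def unwrap(base):
    # Runtime unwrapping: the base used to regenerate an aliased output may
    # arrive as a plain Tensor or wrapped in TensorAlias, so normalize it
    # instead of assuming one representation.
    return base.alias if isinstance(base, TensorAlias) else base


def regenerate_output0(base):
    # Output 0 of the example (`a.detach()`) is rebuilt from its base `a`.
    # Without unwrapping, `base` could be TensorAlias(a) and `.detach()`
    # would fail because TensorAlias is not a Tensor.
    return unwrap(base).detach()


ix = torch.ones(3, 3, requires_grad=True) + 1  # intermediate
a = ix.transpose(0, 1)                         # alias of the intermediate

# The base can show up wrapped or unwrapped depending on output classification.
for base in (a, TensorAlias(a)):
    out0 = regenerate_output0(base)
    assert not out0.requires_grad
    assert out0.data_ptr() == a.data_ptr()  # same storage as the alias
```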
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147638
Approved by: https://github.com/bdhirsh
This commit is contained in:
parent 30db64bf51
commit 8594856651

```diff
@@ -1087,6 +1087,46 @@ def forward(self, arg0_1, arg1_1):
         self.verify_aot_autograd(f, create_inp(True), test_mutation=True)
         self.verify_aot_autograd(f, create_inp(False), test_mutation=True)
 
+    @parametrize("backend", ["aot_eager", "inductor"])
+    @parametrize("view_replay_for_aliased_outputs", [False, True])
+    @parametrize("dynamic_shapes", [False, True])
+    def test_alias_of_intermediate_detach(
+        self, backend, view_replay_for_aliased_outputs, dynamic_shapes
+    ):
+        with patch(
+            "torch._functorch.config.view_replay_for_aliased_outputs",
+            view_replay_for_aliased_outputs,
+        ):
+
+            def fn(x):
+                x = x + 1
+                a = x.transpose(0, 1)
+                return a.detach(), a
+
+            def inp_fn():
+                t = torch.ones(3, 3, requires_grad=True)
+                if dynamic_shapes:
+                    torch._dynamo.mark_dynamic(t, 0)
+                    torch._dynamo.mark_dynamic(t, 1)
+                return t
+
+            x_ref = inp_fn()
+            y_ref = fn(x_ref)
+
+            x = inp_fn()
+            y = torch.compile(fn, backend=backend, fullgraph=True)(x)
+            self.assertEqual(y_ref, y)
+            y0, y1 = y
+            self.assertFalse(y0.requires_grad)
+            self.assertTrue(y1.requires_grad)
+            # Check that detach and diff view points to the same intermediate tensor storage
+            self.assertEqual(y0.data_ptr(), y1.data_ptr())
+            self.assertTrue(y1._is_view())
+
+            sum(y_ref).sum().backward()
+            sum(y).sum().backward()
+            self.assertEqual(x_ref.grad, x.grad)
+
     def test_input_mutation_storage_resize_up(self):
         def f(a):
             torch.ops.inductor.resize_storage_bytes_(a, 32)
```

```diff
@@ -12858,17 +12858,18 @@ if HAS_GPU and not TEST_WITH_ASAN:
             self.assertTrue(max_live_tensors == 3)
 
         # See https://github.com/pytorch/pytorch/issues/100348
-        def test_inductor_detach_view(self):
+        @parametrize("backend", ["aot_eager", "inductor"])
+        def test_inductor_detach_view(self, backend):
             def fn(x: torch.Tensor) -> torch.Tensor:
                 a = x * 2
                 return a, a.detach()
 
-            fn_opt = torch.compile(fn, backend="inductor")
+            fn_opt = torch.compile(fn, backend=backend)
             inp = torch.ones(2, 2, requires_grad=True, device=GPU_TYPE)
             inp_ref = inp.detach().clone().requires_grad_(True)
             out_ref = fn(inp_ref)
-            out = fn_opt(inp)
             out_ref[0].sum().backward()
+            out = fn_opt(inp)
             out[0].sum().backward()
             self.assertEqual(inp.grad, inp_ref.grad)
 
```

```diff
@@ -201,6 +201,7 @@ class IsInputHandler:
 
 class AliasOfIntermediateHandler:
     def __init__(self, info, runtime_metadata, trace_joint):
+        self._unwrap_aliased_base_tensor = _identity
         if info.output_type in (
             OutputType.alias_of_intermediate,
             OutputType.alias_of_intermediate_save_as_output,
@@ -209,6 +210,8 @@ class AliasOfIntermediateHandler:
             self.base_idx = info.base_idx + num_user_outputs
         else:
             self.base_idx = info.base_idx
+            if self.base_idx in runtime_metadata.aliased_out_indices:
+                self._unwrap_aliased_base_tensor = _unwrap_tensoralias
 
         self.unwrap_out = _unwrap_tensoralias if trace_joint else _identity
         self.requires_grad = info.requires_grad
@@ -218,7 +221,7 @@ class AliasOfIntermediateHandler:
     def __call__(self, orig_inputs, fw_outs, out):
         aliased_base_tensor = fw_outs[self.base_idx]
         return gen_alias_from_base(
-            aliased_base_tensor,
+            self._unwrap_aliased_base_tensor(aliased_base_tensor),
             self.unwrap_out(out),
             self.requires_grad,
             self.functional_tensor,
```
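Read together, the hunks above implement a single idea: decide at handler construction time whether the base tensor will arrive wrapped in `TensorAlias`, and unwrap it at call time before it is used to regenerate the aliased output. The following is a condensed, self-contained sketch of that pattern, with simplified names and signatures, not the actual `AliasOfIntermediateHandler` code:
```python
import torch
from dataclasses import dataclass


@dataclass
class TensorAlias:
    # Simplified stand-in for the wrapper applied to aliased outputs.
    alias: torch.Tensor


def _identity(x):
    return x


def _unwrap_tensoralias(x):
    return x.alias


class AliasOfIntermediateHandlerSketch:
    """Illustrative only; mirrors the shape of the fix, not the real class."""

    def __init__(self, base_idx, aliased_out_indices):
        self.base_idx = base_idx
        # Construction-time decision: if the base is itself an aliased output,
        # it arrives wrapped in TensorAlias, so pick the unwrapper up front.
        if base_idx in aliased_out_indices:
            self._unwrap_aliased_base_tensor = _unwrap_tensoralias
        else:
            self._unwrap_aliased_base_tensor = _identity

    def __call__(self, fw_outs):
        base = fw_outs[self.base_idx]
        # Unwrap (if needed) before regenerating the alias from the base,
        # e.g. rebuilding `a.detach()` from `a` in the example above.
        return self._unwrap_aliased_base_tensor(base).detach()


ix = torch.ones(2, 2, requires_grad=True) + 1
a = ix.transpose(0, 1)

# Base arrives as a plain Tensor: its index is not among the aliased outputs.
plain = AliasOfIntermediateHandlerSketch(base_idx=0, aliased_out_indices=set())
# Base arrives wrapped: its index is itself an aliased output.
wrapped = AliasOfIntermediateHandlerSketch(base_idx=0, aliased_out_indices={0})

assert plain([a]).data_ptr() == wrapped([TensorAlias(a)]).data_ptr()
```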