support input mutations on tangents in compile (#141131)

Fixes https://github.com/pytorch/pytorch/issues/141111. We previously supported mutations on saved activations that happened in the backward. This PR extends the support to tangents

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141131
Approved by: https://github.com/zou3519
This commit is contained in:
Brian Hirsh 2025-01-29 09:31:04 -08:00 committed by PyTorch MergeBot
parent 7077d0ac8c
commit 447a142de2
3 changed files with 65 additions and 7 deletions

View File

@ -1417,6 +1417,33 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
out.backward(retain_graph=True)
out.backward()
def test_autograd_function_tangent_mutation(self):
class Foo(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x.clone(), x.clone()
@staticmethod
def backward(ctx, grad1, grad2):
return grad1.copy_(grad2)
def f(x):
return Foo.apply(x)
x = torch.randn(4, requires_grad=True)
x_ref = x.clone().detach().requires_grad_()
out_ref = f(x_ref)
out = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
self.assertEqual(out_ref, out)
self.assertEqual(x_ref, x)
(out[0] + out[1]).sum().backward()
(out_ref[0] + out_ref[1]).sum().backward()
self.assertEqual(x_ref.grad, x.grad)
@torch._functorch.config.patch("donated_buffer", True)
def test_donated_buffer_with_retain_or_create_graph4(self):
# Gives non-empty bw_donated_idxs

View File

@ -2349,11 +2349,18 @@ def forward(self, primals_1, primals_2):
torch.ones(3, 3, requires_grad=True),
torch.ones(3, 3, requires_grad=True),
]
inp_grad_ref = [
torch.ones(3, 3, requires_grad=True),
torch.ones(3, 3, requires_grad=True),
]
f_compiled = aot_function(f, nop)
with self.assertRaisesRegex(
AssertionError, "input to the backward that was mutated during the backward"
):
f_compiled(*inp_grad)
out = f_compiled(*inp_grad)
out.mul(2).sum().backward()
out_ref = f(*inp_grad_ref)
out_ref.mul(2).sum().backward()
self.assertEqual(inp_grad[0].grad, inp_grad_ref[0].grad)
self.assertEqual(inp_grad[1].grad, inp_grad_ref[1].grad)
def test_backward_mutation_forward_inputs(self):
@torch.library.custom_op("_test::_clone", mutates_args={})

View File

@ -38,6 +38,9 @@ from torch.nn.utils import stateless
from .. import config
from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata
from .functional_utils import (
_check_if_mutation_can_be_in_graph,
are_all_mutations_hidden_from_autograd,
are_all_mutations_under_no_grad_or_inference_mode,
from_fun,
has_data_mutation,
has_metadata_mutation,
@ -479,9 +482,30 @@ def create_functionalized_fn(
):
assert not has_metadata_mutation(
f_inpt, before, check_only_storage_mutation=False
) and not has_data_mutation(
), "Found an input to the backward that had metadata mutated during the backward pass. This is not supported"
if has_data_mutation(f_inpt):
can_be_in_graph = _check_if_mutation_can_be_in_graph(
keep_input_mutations=True,
mutates_data=True,
mutates_metadata=False,
mutations_hidden_from_autograd=are_all_mutations_hidden_from_autograd(
f_inpt
), "Found an input to the backward that was mutated during the backward pass. This is not supported"
),
mutations_under_no_grad_or_inference_mode=are_all_mutations_under_no_grad_or_inference_mode(
f_inpt
),
mutates_storage_metadata=False,
mutation_inductor_storage_resize=was_inductor_storage_resized(
f_inpt
),
requires_grad=f_inpt.requires_grad,
)
assert (
can_be_in_graph
), "a backward input that had data mutated in an autograd-aware way. This is not supported"
# Perform the input mutation
with torch.fx.traceback.preserve_node_meta():
before.copy_(after)
if aot_config.keep_inference_input_mutations:
# Note: This is a bit annoying. There's a layering issue here, where: