support input mutations on tangents in compile (#141131)

Fixes https://github.com/pytorch/pytorch/issues/141111. We previously supported mutations on saved activations that happened in the backward. This PR extends the support to tangents

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141131
Approved by: https://github.com/zou3519
This commit is contained in:
Brian Hirsh 2025-01-29 09:31:04 -08:00 committed by PyTorch MergeBot
parent 7077d0ac8c
commit 447a142de2
3 changed files with 65 additions and 7 deletions

View File

@ -1417,6 +1417,33 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
out.backward(retain_graph=True)
out.backward()
def test_autograd_function_tangent_mutation(self):
class Foo(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x.clone(), x.clone()
@staticmethod
def backward(ctx, grad1, grad2):
return grad1.copy_(grad2)
def f(x):
return Foo.apply(x)
x = torch.randn(4, requires_grad=True)
x_ref = x.clone().detach().requires_grad_()
out_ref = f(x_ref)
out = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
self.assertEqual(out_ref, out)
self.assertEqual(x_ref, x)
(out[0] + out[1]).sum().backward()
(out_ref[0] + out_ref[1]).sum().backward()
self.assertEqual(x_ref.grad, x.grad)
@torch._functorch.config.patch("donated_buffer", True)
def test_donated_buffer_with_retain_or_create_graph4(self):
# Gives non-empty bw_donated_idxs

View File

@ -2349,11 +2349,18 @@ def forward(self, primals_1, primals_2):
torch.ones(3, 3, requires_grad=True),
torch.ones(3, 3, requires_grad=True),
]
inp_grad_ref = [
torch.ones(3, 3, requires_grad=True),
torch.ones(3, 3, requires_grad=True),
]
f_compiled = aot_function(f, nop)
with self.assertRaisesRegex(
AssertionError, "input to the backward that was mutated during the backward"
):
f_compiled(*inp_grad)
out = f_compiled(*inp_grad)
out.mul(2).sum().backward()
out_ref = f(*inp_grad_ref)
out_ref.mul(2).sum().backward()
self.assertEqual(inp_grad[0].grad, inp_grad_ref[0].grad)
self.assertEqual(inp_grad[1].grad, inp_grad_ref[1].grad)
def test_backward_mutation_forward_inputs(self):
@torch.library.custom_op("_test::_clone", mutates_args={})

View File

@ -38,6 +38,9 @@ from torch.nn.utils import stateless
from .. import config
from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata
from .functional_utils import (
_check_if_mutation_can_be_in_graph,
are_all_mutations_hidden_from_autograd,
are_all_mutations_under_no_grad_or_inference_mode,
from_fun,
has_data_mutation,
has_metadata_mutation,
@ -479,9 +482,30 @@ def create_functionalized_fn(
):
assert not has_metadata_mutation(
f_inpt, before, check_only_storage_mutation=False
) and not has_data_mutation(
f_inpt
), "Found an input to the backward that was mutated during the backward pass. This is not supported"
), "Found an input to the backward that had metadata mutated during the backward pass. This is not supported"
if has_data_mutation(f_inpt):
can_be_in_graph = _check_if_mutation_can_be_in_graph(
keep_input_mutations=True,
mutates_data=True,
mutates_metadata=False,
mutations_hidden_from_autograd=are_all_mutations_hidden_from_autograd(
f_inpt
),
mutations_under_no_grad_or_inference_mode=are_all_mutations_under_no_grad_or_inference_mode(
f_inpt
),
mutates_storage_metadata=False,
mutation_inductor_storage_resize=was_inductor_storage_resized(
f_inpt
),
requires_grad=f_inpt.requires_grad,
)
assert (
can_be_in_graph
), "a backward input that had data mutated in an autograd-aware way. This is not supported"
# Perform the input mutation
with torch.fx.traceback.preserve_node_meta():
before.copy_(after)
if aot_config.keep_inference_input_mutations:
# Note: This is a bit annoying. There's a layering issue here, where: