support input mutations on tangents in compile (#141131)
Fixes https://github.com/pytorch/pytorch/issues/141111. We previously supported mutations on saved activations that happened in the backward; this PR extends that support to tangents.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141131
Approved by: https://github.com/zou3519
This commit is contained in:
parent 7077d0ac8c
commit 447a142de2
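For context, here is a minimal sketch of the pattern this change enables, distilled from the test added in the first hunk below (the aot_eager backend and the Foo/f names mirror that test):

import torch

# Custom autograd.Function whose backward mutates one of its incoming
# tangents (grad1) in place before returning it.
class Foo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone(), x.clone()

    @staticmethod
    def backward(ctx, grad1, grad2):
        return grad1.copy_(grad2)

def f(x):
    return Foo.apply(x)

x = torch.randn(4, requires_grad=True)
# Previously this pattern could trip the "input to the backward that was
# mutated during the backward" assertion (see the second test hunk below);
# with this change the mutation is traced instead.
out = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
(out[0] + out[1]).sum().backward()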
@@ -1417,6 +1417,33 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
         out.backward(retain_graph=True)
         out.backward()
 
+    def test_autograd_function_tangent_mutation(self):
+        class Foo(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x.clone(), x.clone()
+
+            @staticmethod
+            def backward(ctx, grad1, grad2):
+                return grad1.copy_(grad2)
+
+        def f(x):
+            return Foo.apply(x)
+
+        x = torch.randn(4, requires_grad=True)
+        x_ref = x.clone().detach().requires_grad_()
+
+        out_ref = f(x_ref)
+        out = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
+
+        self.assertEqual(out_ref, out)
+        self.assertEqual(x_ref, x)
+
+        (out[0] + out[1]).sum().backward()
+        (out_ref[0] + out_ref[1]).sum().backward()
+
+        self.assertEqual(x_ref.grad, x.grad)
+
     @torch._functorch.config.patch("donated_buffer", True)
     def test_donated_buffer_with_retain_or_create_graph4(self):
         # Gives non-empty bw_donated_idxs
@@ -2349,11 +2349,18 @@ def forward(self, primals_1, primals_2):
             torch.ones(3, 3, requires_grad=True),
             torch.ones(3, 3, requires_grad=True),
         ]
+        inp_grad_ref = [
+            torch.ones(3, 3, requires_grad=True),
+            torch.ones(3, 3, requires_grad=True),
+        ]
 
         f_compiled = aot_function(f, nop)
-        with self.assertRaisesRegex(
-            AssertionError, "input to the backward that was mutated during the backward"
-        ):
-            f_compiled(*inp_grad)
+        out = f_compiled(*inp_grad)
+        out.mul(2).sum().backward()
+        out_ref = f(*inp_grad_ref)
+        out_ref.mul(2).sum().backward()
+        self.assertEqual(inp_grad[0].grad, inp_grad_ref[0].grad)
+        self.assertEqual(inp_grad[1].grad, inp_grad_ref[1].grad)
 
     def test_backward_mutation_forward_inputs(self):
         @torch.library.custom_op("_test::_clone", mutates_args={})
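For readers unfamiliar with the aot_function(f, nop) idiom in the test above, here is a minimal sketch of that pattern. The nop "compiler" defined below is illustrative (the test suite imports its own); it simply hands back the traced graph module unchanged so that tracing and the runtime wrappers can be exercised without a real backend:

import torch
from torch._functorch.aot_autograd import aot_function

def nop(fx_g, _example_inputs):
    # Pass-through compiler: return the traced fx.GraphModule as-is.
    return fx_g

def f(x):
    return x * 2

f_compiled = aot_function(f, nop)
x = torch.randn(3, requires_grad=True)
f_compiled(x).sum().backward()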
@@ -38,6 +38,9 @@ from torch.nn.utils import stateless
 from .. import config
 from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata
 from .functional_utils import (
+    _check_if_mutation_can_be_in_graph,
+    are_all_mutations_hidden_from_autograd,
+    are_all_mutations_under_no_grad_or_inference_mode,
     from_fun,
     has_data_mutation,
     has_metadata_mutation,
@@ -479,9 +482,30 @@ def create_functionalized_fn(
             ):
                 assert not has_metadata_mutation(
                     f_inpt, before, check_only_storage_mutation=False
-                ) and not has_data_mutation(
-                    f_inpt
-                ), "Found an input to the backward that was mutated during the backward pass. This is not supported"
+                ), "Found an input to the backward that had metadata mutated during the backward pass. This is not supported"
+                if has_data_mutation(f_inpt):
+                    can_be_in_graph = _check_if_mutation_can_be_in_graph(
+                        keep_input_mutations=True,
+                        mutates_data=True,
+                        mutates_metadata=False,
+                        mutations_hidden_from_autograd=are_all_mutations_hidden_from_autograd(
+                            f_inpt
+                        ),
+                        mutations_under_no_grad_or_inference_mode=are_all_mutations_under_no_grad_or_inference_mode(
+                            f_inpt
+                        ),
+                        mutates_storage_metadata=False,
+                        mutation_inductor_storage_resize=was_inductor_storage_resized(
+                            f_inpt
+                        ),
+                        requires_grad=f_inpt.requires_grad,
+                    )
+                    assert (
+                        can_be_in_graph
+                    ), "a backward input that had data mutated in an autograd-aware way. This is not supported"
+                    # Perform the input mutation
+                    with torch.fx.traceback.preserve_node_meta():
+                        before.copy_(after)
 
         if aot_config.keep_inference_input_mutations:
             # Note: This is a bit annoying. There's a layering issue here, where:
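To make the "perform the input mutation in the graph" step above concrete, here is a small standalone sketch. It is not the AOTAutograd code path itself, just an illustration with make_fx showing that an in-place copy_ (like the before.copy_(after) replay above) is recorded as an explicit node in a traced graph rather than being lost:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Toy stand-in for replaying a detected mutation: copying "after" into
# "before" becomes a copy_ node in the captured graph.
def replay_mutation(before, after):
    before.copy_(after)
    return before

gm = make_fx(replay_mutation)(torch.zeros(4), torch.ones(4))
print(gm.graph)  # the printed graph should contain an aten copy_ node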