From 54a7e5b5983d237b324b50703bcb0919a6c4c296 Mon Sep 17 00:00:00 2001
From: Brian Hirsh
Date: Mon, 7 Jul 2025 13:15:42 -0700
Subject: [PATCH] _aot_export_function: allow keeping input mutations in the
 graph (#157730)

Add a `keep_input_mutations` flag to `_aot_export_function` so callers can
opt in to keeping input mutations in the exported graph (as `copy_` ops)
instead of always functionalizing them away.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157730
Approved by: https://github.com/ezyang
---
 test/functorch/test_aotdispatch.py | 32 ++++++++++++++++++++++++++++++
 torch/_functorch/aot_autograd.py   |  8 ++++----
 2 files changed, 36 insertions(+), 4 deletions(-)

diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index cf6ee336b86..f1d1c92d52f 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -51,6 +51,7 @@ from torch._dynamo.testing import normalize_gm
 from torch._dynamo.utils import counters
 from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
 from torch._functorch.aot_autograd import (
+    _aot_export_function,
     aot_export_joint_simple,
     aot_export_module,
     SerializableAOTDispatchCompiler,
@@ -5389,6 +5390,37 @@ def forward(self):
     return (full_1,)""",  # noqa: B950
         )
 
+    def test_aot_export_input_mutation(self):
+        def f(x, buf):
+            buf.add_(1)
+            return buf * x
+
+        x = torch.randn(2, requires_grad=True)
+        buf = torch.zeros(2, requires_grad=False)
+
+        gm, _, _, _ = _aot_export_function(
+            f,
+            (x, buf),
+            decompositions={},
+            num_params_buffers=1,
+            no_tangents=False,
+            pre_dispatch=False,
+            dynamic_shapes=False,
+            keep_input_mutations=True,
+            kwargs={},
+        )
+        self.assertExpectedInline(
+            gm.code.strip(),
+            """\
+def forward(self, primals, tangents):
+    primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
+    add = torch.ops.aten.add.Tensor(primals_2, 1)
+    mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
+    mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = None
+    copy_ = torch.ops.aten.copy_.default(primals_2, add); primals_2 = add = copy_ = None
+    return pytree.tree_unflatten([mul, mul_1, None], self._out_spec)""",
+        )
+
 
 class TestPartitioning(AOTTestCase):
     @unittest.skipIf(not USE_NETWORKX, "networkx not available")
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index 96ef59bfebc..7fe748e2089 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -1572,6 +1572,7 @@ def _aot_export_function(
     pre_dispatch: bool = False,
     # If None, `dynamic_shapes` will be infered from inputs, but the inferred result might be wrong.
     dynamic_shapes: Optional[bool] = None,
+    keep_input_mutations: bool = False,
     kwargs=None,
 ) -> tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]:
     kwargs = kwargs or {}
@@ -1610,7 +1611,6 @@ def _aot_export_function(
-        # For now there's no use case involving keeping input mutations in the graph
-        # (which we can only do in the inference case anyway).
-        # We can add this later if we need to.
-        keep_inference_input_mutations=False,
+        # Callers can opt in to keeping input mutations in the graph;
+        # kept mutations show up as explicit `copy_` ops in the graph.
+        keep_inference_input_mutations=keep_input_mutations,
         dynamic_shapes=dynamic_shapes,
         aot_autograd_arg_pos_to_source=None,
         is_export=True,
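
A minimal usage sketch of the new flag, mirroring the test added above. It assumes a PyTorch build that includes this patch; the mutating function, shapes, and argument values are illustrative, not mandated by the API:

import torch
from torch._functorch.aot_autograd import _aot_export_function

# Illustrative function: mutates its second input in-place, then reads it.
def f(x, buf):
    buf.add_(1)
    return buf * x

x = torch.randn(2, requires_grad=True)     # differentiable input
buf = torch.zeros(2, requires_grad=False)  # buffer-like input that gets mutated

# Same call shape as the new test. With keep_input_mutations=True, the
# buf.add_(1) survives export as an explicit copy_ node instead of being
# functionalized away.
gm, meta, in_spec, out_spec = _aot_export_function(
    f,
    (x, buf),
    decompositions={},
    num_params_buffers=1,
    no_tangents=False,
    pre_dispatch=False,
    dynamic_shapes=False,
    keep_input_mutations=True,
    kwargs={},
)
print(gm.code)  # contains a torch.ops.aten.copy_.default node for the mutation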
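
Downstream code can detect kept mutations by scanning the exported graph for `copy_` write-backs into placeholders. A sketch continuing from the example above; the helper name is hypothetical and uses only standard FX graph APIs:

def kept_input_mutations(gm):
    # Collect functionalized write-backs: aten.copy_ calls whose first
    # argument is a placeholder, i.e. a graph input that was mutated.
    placeholders = {n for n in gm.graph.nodes if n.op == "placeholder"}
    return [
        n
        for n in gm.graph.nodes
        if n.op == "call_function"
        and n.target is torch.ops.aten.copy_.default
        and n.args[0] in placeholders
    ]

# For the graph exported above, exactly one mutation (buf.add_(1)) was kept.
assert len(kept_input_mutations(gm)) == 1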