_aot_export_function: allow keeping input mutations in the graph (#157730)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157730
Approved by: https://github.com/ezyang
This commit is contained in:
Brian Hirsh 2025-07-07 13:15:42 -07:00 committed by PyTorch MergeBot
parent ed03492238
commit 54a7e5b598
2 changed files with 34 additions and 1 deletion
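
For context, a minimal sketch of how the new keep_input_mutations flag is meant to be used, mirroring the test added in this commit (_aot_export_function is a private API, so the exact signature and the printed graph may vary across versions):

import torch
from torch._functorch.aot_autograd import _aot_export_function

def f(x, buf):
    buf.add_(1)  # in-place mutation of the buffer input
    return buf * x

x = torch.randn(2, requires_grad=True)
buf = torch.zeros(2, requires_grad=False)

# With keep_input_mutations=True, the mutation is kept in the exported graph
# as an explicit copy_ node rather than being functionalized away.
gm, meta, in_spec, out_spec = _aot_export_function(
    f,
    (x, buf),
    decompositions={},
    num_params_buffers=1,
    no_tangents=False,
    pre_dispatch=False,
    dynamic_shapes=False,
    keep_input_mutations=True,
    kwargs={},
)
print(gm.code)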

View File

@@ -51,6 +51,7 @@ from torch._dynamo.testing import normalize_gm
from torch._dynamo.utils import counters
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
from torch._functorch.aot_autograd import (
    _aot_export_function,
    aot_export_joint_simple,
    aot_export_module,
    SerializableAOTDispatchCompiler,
@@ -5389,6 +5390,37 @@ def forward(self):
return (full_1,)""", # noqa: B950
)
    def test_aot_export_input_mutation(self):
        def f(x, buf):
            buf.add_(1)
            return buf * x

        x = torch.randn(2, requires_grad=True)
        buf = torch.zeros(2, requires_grad=False)
        gm, _, _, _ = _aot_export_function(
            f,
            (x, buf),
            decompositions={},
            num_params_buffers=1,
            no_tangents=False,
            pre_dispatch=False,
            dynamic_shapes=False,
            keep_input_mutations=True,
            kwargs={},
        )
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, primals, tangents):
    primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    add = torch.ops.aten.add.Tensor(primals_2, 1)
    mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = None
    copy_ = torch.ops.aten.copy_.default(primals_2, add); primals_2 = add = copy_ = None
    return pytree.tree_unflatten([mul, mul_1, None], self._out_spec)""",
        )
class TestPartitioning(AOTTestCase):
    @unittest.skipIf(not USE_NETWORKX, "networkx not available")

View File

@@ -1572,6 +1572,7 @@ def _aot_export_function(
    pre_dispatch: bool = False,
    # If None, `dynamic_shapes` will be inferred from inputs, but the inferred result might be wrong.
    dynamic_shapes: Optional[bool] = None,
    keep_input_mutations: bool = False,
    kwargs=None,
) -> tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]:
    kwargs = kwargs or {}
@@ -1610,7 +1611,7 @@ def _aot_export_function(
        # For now there's no use case involving keeping input mutations in the graph
        # (which we can only do in the inference case anyway).
        # We can add this later if we need to.
        keep_inference_input_mutations=False,
        keep_inference_input_mutations=keep_input_mutations,
        dynamic_shapes=dynamic_shapes,
        aot_autograd_arg_pos_to_source=None,
        is_export=True,