mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
_aot_export_function: allow keeping input mutations in the graph (#157730)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157730 Approved by: https://github.com/ezyang
This commit is contained in:
parent
ed03492238
commit
54a7e5b598
|
|
@ -51,6 +51,7 @@ from torch._dynamo.testing import normalize_gm
|
|||
from torch._dynamo.utils import counters
|
||||
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
|
||||
from torch._functorch.aot_autograd import (
|
||||
_aot_export_function,
|
||||
aot_export_joint_simple,
|
||||
aot_export_module,
|
||||
SerializableAOTDispatchCompiler,
|
||||
|
|
@ -5389,6 +5390,37 @@ def forward(self):
|
|||
return (full_1,)""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_aot_export_input_mutation(self):
|
||||
def f(x, buf):
|
||||
buf.add_(1)
|
||||
return buf * x
|
||||
|
||||
x = torch.randn(2, requires_grad=True)
|
||||
buf = torch.zeros(2, requires_grad=False)
|
||||
|
||||
gm, _, _, _ = _aot_export_function(
|
||||
f,
|
||||
(x, buf),
|
||||
decompositions={},
|
||||
num_params_buffers=1,
|
||||
no_tangents=False,
|
||||
pre_dispatch=False,
|
||||
dynamic_shapes=False,
|
||||
keep_input_mutations=True,
|
||||
kwargs={},
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, primals, tangents):
|
||||
primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
|
||||
add = torch.ops.aten.add.Tensor(primals_2, 1)
|
||||
mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None
|
||||
mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = None
|
||||
copy_ = torch.ops.aten.copy_.default(primals_2, add); primals_2 = add = copy_ = None
|
||||
return pytree.tree_unflatten([mul, mul_1, None], self._out_spec)""",
|
||||
)
|
||||
|
||||
|
||||
class TestPartitioning(AOTTestCase):
|
||||
@unittest.skipIf(not USE_NETWORKX, "networkx not available")
|
||||
|
|
|
|||
|
|
@ -1572,6 +1572,7 @@ def _aot_export_function(
|
|||
pre_dispatch: bool = False,
|
||||
# If None, `dynamic_shapes` will be infered from inputs, but the inferred result might be wrong.
|
||||
dynamic_shapes: Optional[bool] = None,
|
||||
keep_input_mutations: bool = False,
|
||||
kwargs=None,
|
||||
) -> tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]:
|
||||
kwargs = kwargs or {}
|
||||
|
|
@ -1610,7 +1611,7 @@ def _aot_export_function(
|
|||
# For now there's no use case involving keeping input mutations in the graph
|
||||
# (which we can only do in the inference case anyway).
|
||||
# We can add this later if we need to.
|
||||
keep_inference_input_mutations=False,
|
||||
keep_inference_input_mutations=keep_input_mutations,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
aot_autograd_arg_pos_to_source=None,
|
||||
is_export=True,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user