[aotd] Support mutations in reordering_to_mimic_autograd_engine (#155353)

Original issue: https://github.com/pytorch/pytorch/issues/154820

Dedicated sub-issue: https://github.com/pytorch/pytorch/issues/155242

The backward graph is reordered by reordering_to_mimic_autograd_engine in partitioners.py,
which only records into the reordered backward graph the compute that starts from the tangents.

A mutation of primals (inputs) performed in backward can therefore be disconnected from the tangent-rooted compute and get dropped by the reordering.

We handle this copy_ specifically: the framework is what adds this mutation, and it is the only kind of mutation that exists in the graph.
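For illustration, here is a minimal sketch (not the partitioner code itself; the toy graph and traversal are assumptions made for the example) of how a walk rooted at the tangents misses a copy_ that only depends on primals:

    import torch
    from torch import fx

    # Stand-in for a backward graph: the gradient compute depends on tangents_1,
    # while the copy_ that applies the input mutation hangs off primals_1 only.
    g = fx.Graph()
    primals_1 = g.placeholder("primals_1")
    tangents_1 = g.placeholder("tangents_1")
    grad = g.call_function(torch.ops.aten.mul.Tensor, (tangents_1, primals_1))
    mutation = g.call_function(torch.ops.aten.add.Tensor, (primals_1, 1))
    copy = g.call_function(torch.ops.aten.copy_.default, (primals_1, mutation))
    g.output((grad,))

    # Walk the graph starting only from the tangent-rooted gradient output,
    # the way a tangent-first reordering would discover nodes.
    reachable, stack = set(), [grad]
    while stack:
        node = stack.pop()
        if node not in reachable:
            reachable.add(node)
            stack.extend(node.all_input_nodes)

    print(copy in reachable)  # False: the mutation is invisible to this walk

A graph rebuilt only from nodes reachable in that walk would silently lose the copy_, which is why it needs special handling.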

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155353
Approved by: https://github.com/bdhirsh, https://github.com/zou3519
IvanKobzarev 2025-06-06 12:48:39 -07:00 committed by PyTorch MergeBot
parent 6c05f2fca0
commit 0083032e75
2 changed files with 42 additions and 0 deletions

@@ -7779,6 +7779,43 @@ class TestAOTAutogradWithDynamo(TestAOTAutograd):
        self.assertEqual(out, optout)

    def test_mutations_in_bw_detached_from_tangent(self):
        class AF(torch.autograd.Function):
            @staticmethod
            def forward(ctx, dummy, inplace_tensor):
                ctx.inplace_tensor = inplace_tensor
                return dummy.clone()

            @staticmethod
            def backward(ctx, grad_output):
                inplace_tensor = ctx.inplace_tensor
                gradient_attachment = grad_output * 0 + 1
                inplace_tensor.add_(1 * gradient_attachment)
                return grad_output, None, None

        def fn(dummy, inplace_tensor):
            return AF.apply(dummy, inplace_tensor)

        def _inps():
            dummy = torch.zeros((2,), requires_grad=True)
            inplace_tensor = torch.zeros((2,), requires_grad=False)
            return dummy, inplace_tensor

        # Eager reference run
        inps = _inps()
        out = fn(*inps)
        ref_inps_after_fw = [x.clone().detach() for x in inps]
        out.sum().backward()
        ref_inps_after_bw = [x.clone().detach() for x in inps]

        # Compiled run
        inps = _inps()
        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(*inps)
        inps_after_fw = [x.clone().detach() for x in inps]
        out.sum().backward()
        inps_after_bw = [x.clone().detach() for x in inps]

        self.assertEqual(ref_inps_after_fw, inps_after_fw)
        self.assertEqual(ref_inps_after_bw, inps_after_bw)


class MockFXGraphCache:
    """

@@ -1141,6 +1141,11 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
        return gm

    # Build the graph op-by-op by starting from the node all the way to the end
    # copy_ may not use tangents at all; we must still copy it.
    for node in list(gm.graph.nodes)[: order[first_node_in_bwd]]:
        if node.op == "call_function" and node.target == torch.ops.aten.copy_.default:
            insert_node_in_graph(node)

    for node in list(gm.graph.nodes)[order[first_node_in_bwd] :]:
        insert_node_in_graph(node)
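With this change the pass first inserts any copy_ nodes that appear before the first tangent-rooted backward node, and only then replays the tangent-rooted portion of the graph, so the framework-inserted input mutation is preserved even though it has no path from the tangents.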