diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 33c1da771fe..c5ad363343a 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -7779,6 +7779,43 @@ class TestAOTAutogradWithDynamo(TestAOTAutograd):
 
         self.assertEqual(out, optout)
 
+    def test_mutations_in_bw_detached_from_tangent(self):
+        class AF(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, dummy, inplace_tensor):
+                ctx.inplace_tensor = inplace_tensor
+                return dummy.clone()
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                inplace_tensor = ctx.inplace_tensor
+                gradient_attachment = grad_output * 0 + 1
+                inplace_tensor.add_(1 * gradient_attachment)
+                return grad_output, None, None
+
+        def fn(dummy, inplace_tensor):
+            return AF.apply(dummy, inplace_tensor)
+
+        def _inps():
+            dummy = torch.zeros((2,), requires_grad=True)
+            inplace_tensor = torch.zeros((2,), requires_grad=False)
+            return dummy, inplace_tensor
+
+        inps = _inps()
+        out = fn(*inps)
+        ref_inps_after_fw = [x.clone().detach() for x in inps]
+        out.sum().backward()
+        ref_inps_after_bw = [x.clone().detach() for x in inps]
+
+        inps = _inps()
+        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(*inps)
+        inps_after_fw = [x.clone().detach() for x in inps]
+        out.sum().backward()
+        inps_after_bw = [x.clone().detach() for x in inps]
+
+        self.assertEqual(ref_inps_after_fw, inps_after_fw)
+        self.assertEqual(ref_inps_after_bw, inps_after_bw)
+
 
 class MockFXGraphCache:
     """
diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py
index 826906ec0ba..60d125116a6 100644
--- a/torch/_functorch/partitioners.py
+++ b/torch/_functorch/partitioners.py
@@ -1141,6 +1141,11 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
         return gm
 
     # Build the graph op-by-op by starting from the node all the way to the end
+    # copy_ nodes may not use the tangents at all; they must still be included in the graph.
+    for node in list(gm.graph.nodes)[: order[first_node_in_bwd]]:
+        if node.op == "call_function" and node.target == torch.ops.aten.copy_.default:
+            insert_node_in_graph(node)
+
     for node in list(gm.graph.nodes)[order[first_node_in_bwd] :]:
         insert_node_in_graph(node)
 
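
For context, the pattern the new test guards: a custom `autograd.Function` whose backward mutates a saved tensor through a chain that touches the incoming gradient only trivially (`grad_output * 0 + 1`). Below is a minimal standalone sketch of the eager behavior that the compiled path must reproduce; the `MutateInBackward` name and `side_buffer` tensor are illustrative, not part of the patch.

```python
import torch

class MutateInBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dummy, side_buffer):
        # Stash the tensor that backward will mutate in place.
        ctx.side_buffer = side_buffer
        return dummy.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # The written value is constant, but the expression is routed
        # through grad_output, mirroring the test's gradient_attachment.
        attachment = grad_output * 0 + 1
        ctx.side_buffer.add_(attachment)
        return grad_output, None

dummy = torch.zeros(2, requires_grad=True)
side_buffer = torch.zeros(2)
out = MutateInBackward.apply(dummy, side_buffer)
out.sum().backward()
print(side_buffer)  # tensor([1., 1.]): the mutation lands at backward time
```

Under functionalization this in-place update becomes an `aten.copy_` back into the input, and because its value does not genuinely depend on the tangents, the reordering pass could schedule it before the first backward node and then drop it; the test asserts that the compiled inputs match eager after both forward and backward.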
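The partitioner side of the change scans the node prefix that `reordering_to_mimic_autograd_engine` previously skipped and pulls `aten.copy_.default` calls into the rebuilt graph. A minimal sketch of that detection, assuming only the public `make_fx` tracer (`find_copy_nodes` and `f` are hypothetical helpers, not part of the patch):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def find_copy_nodes(gm):
    # Same predicate as the new loop in partitioners.py.
    return [
        n
        for n in gm.graph.nodes
        if n.op == "call_function" and n.target == torch.ops.aten.copy_.default
    ]

def f(dst, src):
    # An in-place copy traces to torch.ops.aten.copy_.default.
    return dst.copy_(src)

gm = make_fx(f)(torch.zeros(2), torch.ones(2))
print(find_copy_nodes(gm))  # expect a single aten.copy_.default node
```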