Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[aotd] Support mutations in reordering_to_mimic_autograd_engine (#155353)
Original issue: https://github.com/pytorch/pytorch/issues/154820
Dedicated sub-issue: https://github.com/pytorch/pytorch/issues/155242

The backward graph is reordered by reordering_to_mimic_autograd_engine in partitioners.py, which only records compute that starts from the tangents. A mutation of primals (inputs) performed in the backward can therefore be disconnected from the tangents and dropped from the reordered graph. The copy_ is handled specifically here, since the framework itself adds this mutation and it is the only kind of mutation that exists in the backward graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155353
Approved by: https://github.com/bdhirsh, https://github.com/zou3519
This commit is contained in:
parent 6c05f2fca0
commit 0083032e75
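For context, here is a minimal sketch of the failure mode, written against plain torch.fx rather than the AOTAutograd partitioner itself (the function bwd and all variable names below are illustrative, not part of the patch): a walk that starts from the tangent inputs never reaches the copy_ that carries the input mutation.

import torch
from torch import fx

# A toy "backward": the gradient compute depends on tangents_1, while the
# copy_ into primals_2 only depends on the primal and the mutation value.
def bwd(primals_2, tangents_1):
    grad = tangents_1 * 1.0
    update = torch.ones_like(primals_2)
    primals_2.copy_(primals_2 + update)  # input mutation, tangent-free
    return grad

gm = fx.symbolic_trace(bwd)

# Walk forward from the tangent placeholder, the way a tangent-driven
# reordering would: collect everything transitively reachable from tangents_1.
reachable = set()
stack = [n for n in gm.graph.nodes if n.op == "placeholder" and "tangents" in n.name]
while stack:
    node = stack.pop()
    if node in reachable:
        continue
    reachable.add(node)
    stack.extend(node.users)

# The mutation never shows up: it is disconnected from the tangents.
copy_nodes = [n for n in gm.graph.nodes if n.op == "call_method" and n.target == "copy_"]
print([n in reachable for n in copy_nodes])  # [False]

Any reordering that replays only the reachable set silently drops the copy_; since AOTAutograd itself installs that copy_ to apply input mutations, the compiled backward would lose the mutation.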
@@ -7779,6 +7779,43 @@ class TestAOTAutogradWithDynamo(TestAOTAutograd):
        self.assertEqual(out, optout)

    def test_mutations_in_bw_detached_from_tangent(self):
        class AF(torch.autograd.Function):
            @staticmethod
            def forward(ctx, dummy, inplace_tensor):
                ctx.inplace_tensor = inplace_tensor
                return dummy.clone()

            @staticmethod
            def backward(ctx, grad_output):
                inplace_tensor = ctx.inplace_tensor
                gradient_attachment = grad_output * 0 + 1
                inplace_tensor.add_(1 * gradient_attachment)
                return grad_output, None, None

        def fn(dummy, inplace_tensor):
            return AF.apply(dummy, inplace_tensor)

        def _inps():
            dummy = torch.zeros((2,), requires_grad=True)
            inplace_tensor = torch.zeros((2,), requires_grad=False)
            return dummy, inplace_tensor

        inps = _inps()
        out = fn(*inps)
        ref_inps_after_fw = [x.clone().detach() for x in inps]
        out.sum().backward()
        ref_inps_after_bw = [x.clone().detach() for x in inps]

        inps = _inps()
        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(*inps)
        inps_after_fw = [x.clone().detach() for x in inps]
        out.sum().backward()
        inps_after_bw = [x.clone().detach() for x in inps]

        self.assertEqual(ref_inps_after_fw, inps_after_fw)
        self.assertEqual(ref_inps_after_bw, inps_after_bw)


class MockFXGraphCache:
    """
@@ -1141,6 +1141,11 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
        return gm

    # Build the graph op-by-op by starting from the node all the way to the end.
    # copy_ may not use the tangents at all, but we must still copy it into the new graph.
    for node in list(gm.graph.nodes)[: order[first_node_in_bwd]]:
        if node.op == "call_function" and node.target == torch.ops.aten.copy_.default:
            insert_node_in_graph(node)

    for node in list(gm.graph.nodes)[order[first_node_in_bwd] :]:
        insert_node_in_graph(node)
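The effect of the added loop, sketched on a toy dependency graph (plain Python with illustrative names; the real insert_node_in_graph also pulls in a node's not-yet-inserted inputs, which the helper below imitates):

# Toy backward graph: node -> inputs. The copy_ node applies the input
# mutation and has no path from tangents_1.
graph = {
    "primals_2": [],
    "tangents_1": [],
    "ones": ["primals_2"],
    "add": ["primals_2", "ones"],
    "copy_": ["primals_2", "add"],   # framework-inserted input mutation
    "mul": ["tangents_1"],           # gradient compute
    "output": ["mul"],
}
tangent_driven = ["mul", "output"]   # what the pass records when starting from tangents

ordered, seen = [], set()

def insert_node(node):
    """Append `node` after its not-yet-inserted inputs (mimics insert_node_in_graph)."""
    if node in seen:
        return
    for dep in graph[node]:
        insert_node(dep)
    seen.add(node)
    ordered.append(node)

# The fix: scan the prefix the tangent-driven walk would skip and pin down
# any copy_ nodes (and, through the helper, their dependencies) first.
for node in graph:
    if node == "copy_":
        insert_node(node)

# Then replay the tangent-driven portion as before.
for node in tangent_driven:
    insert_node(node)

print(ordered)
# ['primals_2', 'ones', 'add', 'copy_', 'tangents_1', 'mul', 'output']

Without the first loop, copy_ and its inputs are never appended, and the compiled backward skips the mutation, which is exactly the eager-vs-compiled divergence the new test checks.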