[aot] fix deepcopying of aot bwd containing real tensors (#153999)

Previously, when we lower the backward graph ahead of time (because symints are saved for backward), the post-grad passes would leave bw_module in a non-runnable state. This caused issues when compiled autograd tried to trace it at runtime, so we had Inductor operate on a deepcopy of bw_module instead.
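A minimal sketch of that pattern (hypothetical helper name, not code from this PR): compiling a deepcopy lets the compiler's in-place passes run without corrupting the original module.

    import copy

    def lower_backward(bw_module, bw_compiler, placeholder_list):
        # The compiler's post-grad passes mutate the graph they are given,
        # so hand them a deepcopy and keep bw_module runnable for
        # compiled autograd to trace later.
        return bw_compiler(copy.deepcopy(bw_module), placeholder_list)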

But with https://github.com/pytorch/pytorch/issues/153993, we see that deepcopying real tensors fails under fake mode, due to the device mismatch between fake tensors (which live on the "meta" device) and the real tensors being copied. Disabling fake mode around the deepcopy avoids these errors. This change is a strict improvement over the current behavior, but it does reveal that the deepcopy can theoretically cause OOMs, since it duplicates any real tensors held by bw_module.
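As a standalone illustration (a sketch under assumed behavior; Holder and its constant are made up, not from this PR), deepcopying a module that holds a real tensor while a FakeTensorMode is active can hit the meta/real device mismatch, and unset_fake_temporarily sidesteps it:

    import copy
    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

    class Holder(torch.nn.Module):
        # Stands in for a GraphModule carrying a lifted real-tensor constant.
        def __init__(self):
            super().__init__()
            self.const = torch.tensor(2.0)  # real tensor, not fake

    m = Holder()

    with FakeTensorMode():
        try:
            copy.deepcopy(m)  # may raise: copied storage is fakeified onto "meta"
        except Exception as e:
            print("deepcopy under fake mode failed:", e)

        # Exiting fake mode just for the copy, as the fix does, succeeds, at
        # the cost of duplicating the real tensor's memory.
        with unset_fake_temporarily():
            m_copy = copy.deepcopy(m)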

FIXES https://github.com/pytorch/pytorch/issues/153993

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153999
Approved by: https://github.com/jamesjwu, https://github.com/bdhirsh
Simon Fan 2025-05-21 09:37:40 -07:00 committed by PyTorch MergeBot
parent 67f9feeee7
commit fae6f6c9ca
2 changed files with 27 additions and 1 deletion


@@ -7234,6 +7234,26 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
         output = capturedOutput.getvalue()
         self.assertNotIn("class GraphModule", output)
 
+    def test_deepcopy_constant_tensor_in_aot_bwd(self):
+        class Fn(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x + 1
+
+            @staticmethod
+            def backward(ctx, grad_out):
+                return grad_out * torch.tensor(2) * grad_out.shape[0]
+
+        def f(x):
+            return Fn.apply(x)
+
+        x = torch.randn(8, requires_grad=True)
+        out = f(x)  # should not raise
+        c_out = torch.compile(f, backend="aot_eager", dynamic=True)(x)
+        expected = torch.autograd.grad(out.sum(), inputs=(x,))
+        actual = torch.autograd.grad(c_out.sum(), inputs=(x,))
+        self.assertEqual(expected, actual)
+
 
 instantiate_parametrized_tests(ReproTests)


@@ -1211,9 +1211,15 @@ def aot_dispatch_autograd(
     if num_symints_saved_for_bw > 0:
         try:
             # See Note: [Backward graph lazy lowering]
+            with torch._subclasses.fake_tensor.unset_fake_temporarily():
+                # If bw_module contains lifted constants, they will be real tensors
+                # stored as attributes on the GraphModule. Deepcopying tensors under
+                # fake mode is not supported and will raise when attempting to set storage.
+                bw_module_copy = copy.deepcopy(bw_module)
             compiled_bw_func = aot_config.bw_compiler(
-                copy.deepcopy(bw_module), placeholder_list
+                bw_module_copy, placeholder_list
             )
+            del bw_module_copy
         except Exception as e:
             exc = e
             trace_structured(