[aot] fix deepcopying of aot bwd containing real tensors (#153999)
Previously, when we lower the backward graph ahead of time (because symints are saved for backward), the post-grad passes would leave bw_module in a non-runnable state. This caused issues when compiled autograd tried to trace it at runtime, so we had inductor operate on a deepcopy of bw_module instead.

But as https://github.com/pytorch/pytorch/issues/153993 shows, deepcopying real tensors fails under fake mode because of the device mismatch between fake tensors (on the "meta" device) and the real tensors. Performing the deepcopy with fake mode temporarily disabled avoids these errors. This change is a strict improvement over the current behavior, but it does reveal that this deepcopy can theoretically cause OOMs.

FIXES https://github.com/pytorch/pytorch/issues/153993

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153999
Approved by: https://github.com/jamesjwu, https://github.com/bdhirsh
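For illustration, here is a minimal sketch (not code from this PR) of the failure mode and the workaround: deepcopying a module that holds a real constant tensor while a FakeTensorMode is active can raise, while wrapping the deepcopy in unset_fake_temporarily() lets it run on real tensors. FakeTensorMode and unset_fake_temporarily are the actual torch APIs used by the fix; the TinyBw class and variable names are made up for the example.

# Minimal sketch of the fake-mode deepcopy issue; TinyBw is a hypothetical stand-in
# for a bw_module carrying a lifted constant as a real tensor attribute.
import copy

import torch
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily


class TinyBw(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Real (non-fake) constant, analogous to a lifted constant on a GraphModule.
        self.register_buffer("scale", torch.tensor(2.0))

    def forward(self, grad_out):
        return grad_out * self.scale


bw = TinyBw()

with FakeTensorMode():
    try:
        # Deepcopying the real buffer while fake mode is active can fail, e.g. with a
        # device mismatch when real storage is set on a "meta"-device fake tensor.
        copy.deepcopy(bw)
    except Exception as e:
        print("deepcopy under fake mode raised:", type(e).__name__)

    # Temporarily disabling fake mode keeps the deepcopy a plain real-tensor copy,
    # which mirrors what the fix does before handing bw_module to the bw_compiler.
    with unset_fake_temporarily():
        bw_copy = copy.deepcopy(bw)

print(bw_copy.scale)  # tensor(2.)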
This commit is contained in:
parent
67f9feeee7
commit
fae6f6c9ca
@@ -7234,6 +7234,26 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
         output = capturedOutput.getvalue()
         self.assertNotIn("class GraphModule", output)
 
+    def test_deepcopy_constant_tensor_in_aot_bwd(self):
+        class Fn(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x + 1
+
+            @staticmethod
+            def backward(ctx, grad_out):
+                return grad_out * torch.tensor(2) * grad_out.shape[0]
+
+        def f(x):
+            return Fn.apply(x)
+
+        x = torch.randn(8, requires_grad=True)
+        out = f(x)  # should not raise
+        c_out = torch.compile(f, backend="aot_eager", dynamic=True)(x)
+        expected = torch.autograd.grad(out.sum(), inputs=(x,))
+        actual = torch.autograd.grad(c_out.sum(), inputs=(x,))
+        self.assertEqual(expected, actual)
+
 
 instantiate_parametrized_tests(ReproTests)
 
@@ -1211,9 +1211,15 @@ def aot_dispatch_autograd(
     if num_symints_saved_for_bw > 0:
         try:
             # See Note: [Backward graph lazy lowering]
+            with torch._subclasses.fake_tensor.unset_fake_temporarily():
+                # If bw_module contains lifted constants, they will be real tensors stored as
+                # GraphModule attributes. Deepcopying tensors under fake mode is not supported
+                # and will raise when attempting to set storage.
+                bw_module_copy = copy.deepcopy(bw_module)
             compiled_bw_func = aot_config.bw_compiler(
-                copy.deepcopy(bw_module), placeholder_list
+                bw_module_copy, placeholder_list
             )
+            del bw_module_copy
         except Exception as e:
             exc = e
             trace_structured(