diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 9e0e1fe29a1..e7811576c74 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -7234,6 +7234,26 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
         output = capturedOutput.getvalue()
         self.assertNotIn("class GraphModule", output)
 
+    def test_deepcopy_constant_tensor_in_aot_bwd(self):
+        class Fn(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x + 1
+
+            @staticmethod
+            def backward(ctx, grad_out):
+                return grad_out * torch.tensor(2) * grad_out.shape[0]
+
+        def f(x):
+            return Fn.apply(x)
+
+        x = torch.randn(8, requires_grad=True)
+        out = f(x)  # should not raise
+        c_out = torch.compile(f, backend="aot_eager", dynamic=True)(x)
+        expected = torch.autograd.grad(out.sum(), inputs=(x,))
+        actual = torch.autograd.grad(c_out.sum(), inputs=(x,))
+        self.assertEqual(expected, actual)
+
 
 instantiate_parametrized_tests(ReproTests)
 
diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
index 447e3f9c70d..56303d6752b 100644
--- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
@@ -1211,9 +1211,15 @@ def aot_dispatch_autograd(
         if num_symints_saved_for_bw > 0:
             try:
                 # See Note: [Backward graph lazy lowering]
+                with torch._subclasses.fake_tensor.unset_fake_temporarily():
+                    # If bw_module contains lifted constants, they will be real tensors
+                    # stored on the GraphModule. Deepcopying tensors under fake mode is
+                    # not supported and will raise when attempting to set storage.
+                    bw_module_copy = copy.deepcopy(bw_module)
                 compiled_bw_func = aot_config.bw_compiler(
-                    copy.deepcopy(bw_module), placeholder_list
+                    bw_module_copy, placeholder_list
                 )
+                del bw_module_copy
             except Exception as e:
                 exc = e
                 trace_structured(
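
For context, the sketch below is not part of the patch; it is a minimal illustration of the failure mode the second hunk works around: deepcopying a real tensor while a FakeTensorMode is active raises, which is why the deepcopy of bw_module is moved under unset_fake_temporarily() before the copy is handed to the bw_compiler. The names constant and clone are illustrative; only internals already referenced by the patch (FakeTensorMode, unset_fake_temporarily) are used.

    # Minimal sketch of the failure the patch avoids (illustrative, not from the PR).
    import copy

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

    constant = torch.tensor(2)  # stands in for a lifted constant held by bw_module

    with FakeTensorMode():
        try:
            # Deepcopying a real tensor with fake mode active is expected to raise.
            copy.deepcopy(constant)
        except Exception as e:
            print(f"deepcopy under fake mode failed: {type(e).__name__}")

        # Temporarily disabling fake mode lets the constant be copied normally,
        # mirroring what the patched code does before calling the bw_compiler.
        with unset_fake_temporarily():
            clone = copy.deepcopy(constant)

    print(torch.equal(clone, constant))  # expected: True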