[compiled autograd] fix saved tensor hook firing count (#134361)

The SavedVariable constructor calls the pack hooks. We don't want to call them for the proxy tensor, since it proxies a tensor that already had the pack hook called on it during the forward pass.

This uses the same fix as https://github.com/pytorch/pytorch/pull/123196.
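
For reference, a minimal eager-mode sketch (not part of this change) of the contract the fix preserves, using only the public `torch.autograd.graph.saved_tensors_hooks` API: pack_hook fires once when a tensor is saved for backward during the forward pass, and unpack_hook fires once when it is read back during backward. The proxy SavedVariable built while tracing should not trigger a second pack.

```python
import torch

pack_calls, unpack_calls = 0, 0

def pack_hook(t):
    global pack_calls
    pack_calls += 1
    return t

def unpack_hook(t):
    global unpack_calls
    unpack_calls += 1
    return t

x = torch.ones(4, requires_grad=True)
y = torch.ones(4)  # does not require grad

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    out = x * y  # mul saves y (needed for x's gradient), so pack_hook fires once

assert (pack_calls, unpack_calls) == (1, 0)

out.sum().backward()  # the saved tensor is unpacked exactly once during backward

assert (pack_calls, unpack_calls) == (1, 1)
```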

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134361
Approved by: https://github.com/jansel
ghstack dependencies: #134186, #134200, #134205, #134286, #134290, #134162, #134163
Simon Fan 2024-08-23 12:43:41 -07:00 committed by PyTorch MergeBot
parent 929de1d0d4
commit ff7d94c67e
2 changed files with 40 additions and 1 deletion


@@ -2430,6 +2430,44 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
        self.assertEqual(sum(1 for e in unexpected_logs if e in logs.getvalue()), 0)

    @unittest.expectedFailure
    def test_saved_tensor_unpack_hook_ordering(self):
        # not the correct behaviour, I'm just preventing this from changing silently
        def f(x, y):
            return x * y

        pack_count = 0
        unpack_count = 0

        def pack_hook(x):
            nonlocal pack_count
            pack_count += 1
            return x

        def unpack_hook(x):
            nonlocal unpack_count
            unpack_count += 1
            return x

        def tensor_hook(_):
            # in eager, tensor_hook is fired before unpack_hook
            # but in compiled autograd, tensor_hook is lifted whereas unpack_hook is not
            self.assertEqual(unpack_count, 0)

        x = torch.ones(4, requires_grad=True)
        y = torch.ones(4, requires_grad=False)
        with torch.autograd.graph.saved_tensors_hooks(
            pack_hook, unpack_hook
        ), compiled_autograd.enable(make_compiler_fn(fullgraph=False)):
            out_test = f(x, y)
            self.assertEqual(pack_count, 1)
            self.assertEqual(unpack_count, 0)

            loss = out_test.sum()
            loss.register_hook(tensor_hook)
            loss.backward()
            self.assertEqual(pack_count, 1)
            self.assertEqual(unpack_count, 1)


def load_test_module(name):
    testdir = Path(__file__).absolute().parent.parent
@@ -2619,7 +2657,6 @@ known_failing_tests = {
    # Category: Divergence from eager
    "test_invalid_gradients",  # can't give autograd error due to inaccurate output metadata of lifted backward
    "test_autograd_node_isinstance",  # backward ctx is a fake cls and not directly a Node instance
    "test_unpack_hooks_exec_count",  # saved tensor packed twice
    # Uncategorized
}


@@ -599,8 +599,10 @@ class SwapSavedVariables {
    TensorArg& arg = compiler.tensor_args.lookup(t);
    stashed_variables.save(&t, std::move(t));
    if (arg.defined()) {
      bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
      TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
      t = SavedVariable(arg.proxy_tensor, false);
      at::SavedTensorDefaultHooks::set_tracing(prior);
    }
  }

  void after(SavedVariable& t) {