[compiled autograd] fix saved tensor hook firing count (#134361)
The SavedVariable constructor calls the pack hooks. We don't want to call them for the proxy tensor, since it proxies a tensor whose pack hook was already called during the forward pass. This uses the same fix as https://github.com/pytorch/pytorch/pull/123196.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134361
Approved by: https://github.com/jansel
ghstack dependencies: #134186, #134200, #134205, #134286, #134290, #134162, #134163
This commit is contained in:
parent 929de1d0d4
commit ff7d94c67e
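For context on what "saved tensor hook firing count" means, here is a minimal eager-mode sketch (illustrative only, not part of this PR): with torch.autograd.graph.saved_tensors_hooks installed, the pack hook is expected to fire once per tensor saved during forward, and the unpack hook once per saved tensor used during backward. The bug fixed here caused compiled autograd to pack a second time when swapping in its proxy tensor.

import torch

# Minimal eager-mode sketch of saved-tensor hook counts (not PyTorch test code).
pack_count = 0
unpack_count = 0

def pack_hook(t):
    global pack_count
    pack_count += 1
    return t

def unpack_hook(t):
    global unpack_count
    unpack_count += 1
    return t

x = torch.ones(4, requires_grad=True)
y = torch.ones(4, requires_grad=False)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    loss = (x * y).sum()  # mul saves y (needed for x's gradient), so one pack is expected
loss.backward()
print(pack_count, unpack_count)  # expected: 1 1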
@@ -2430,6 +2430,44 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
        self.assertEqual(sum(1 for e in unexpected_logs if e in logs.getvalue()), 0)

    @unittest.expectedFailure
    def test_saved_tensor_unpack_hook_ordering(self):
        # not the correct behaviour, I'm just preventing this from changing silently
        def f(x, y):
            return x * y

        pack_count = 0
        unpack_count = 0

        def pack_hook(x):
            nonlocal pack_count
            pack_count += 1
            return x

        def unpack_hook(x):
            nonlocal unpack_count
            unpack_count += 1
            return x

        def tensor_hook(_):
            # in eager, tensor_hook is fired before unpack_hook
            # but in compiled autograd, tensor_hook is lifted whereas unpack_hook is not
            self.assertEqual(unpack_count, 0)

        x = torch.ones(4, requires_grad=True)
        y = torch.ones(4, requires_grad=False)
        with torch.autograd.graph.saved_tensors_hooks(
            pack_hook, unpack_hook
        ), compiled_autograd.enable(make_compiler_fn(fullgraph=False)):
            out_test = f(x, y)
            self.assertEqual(pack_count, 1)
            self.assertEqual(unpack_count, 0)
            loss = out_test.sum()
            loss.register_hook(tensor_hook)
            loss.backward()
            self.assertEqual(pack_count, 1)
            self.assertEqual(unpack_count, 1)


def load_test_module(name):
    testdir = Path(__file__).absolute().parent.parent
@@ -2619,7 +2657,6 @@ known_failing_tests = {
    # Category: Divergence from eager
    "test_invalid_gradients",  # can't give autograd error due to inaccurate output metadata of lifted backward
    "test_autograd_node_isinstance",  # backward ctx is a fake cls and not directly a Node instance
    "test_unpack_hooks_exec_count",  # saved tensor packed twice
    # Uncategorized
}
@@ -599,8 +599,10 @@ class SwapSavedVariables {
    TensorArg& arg = compiler.tensor_args.lookup(t);
    stashed_variables.save(&t, std::move(t));
    if (arg.defined()) {
      bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
      TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
      t = SavedVariable(arg.proxy_tensor, false);
      at::SavedTensorDefaultHooks::set_tracing(prior);
    }
  }
  void after(SavedVariable& t) {
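The C++ hunk above is a save-and-restore guard: at::SavedTensorDefaultHooks::set_tracing(true) keeps the SavedVariable constructor from invoking the pack hooks while the proxy tensor is swapped in, and the prior value is restored afterwards so surrounding code is unaffected. As a rough illustration of that guard pattern, here is a Python sketch using a hypothetical flag (not PyTorch source):

import contextlib

_tracing = False  # hypothetical stand-in for the C++ tracing flag

@contextlib.contextmanager
def tracing_enabled():
    # Mirrors `bool prior = set_tracing(true); ...; set_tracing(prior);`:
    # flip the flag, do the work, then restore whatever value was there before,
    # so the guard composes safely when nested.
    global _tracing
    prior = _tracing
    _tracing = True
    try:
        yield
    finally:
        _tracing = prior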