Add warning for module full backward hook when no input requires gradient (#155339)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155339
Approved by: https://github.com/Skylion007
Mikayla Gawarecki 2025-06-06 13:48:39 -07:00 committed by PyTorch MergeBot
parent e25ce0f928
commit 671a9d175b
2 changed files with 13 additions and 1 deletion

@@ -1445,6 +1445,13 @@ class TestModuleHookNN(NNTestCase):
        mod.register_full_backward_hook(hook)

        # This should run and trigger the hook properly
        with self.assertWarnsRegex(
            UserWarning,
            (
                "Full backward hook is firing when gradients are computed with "
                "respect to module outputs since no inputs require gradients"
            ),
        ):
            mod(inp).sum().backward()
        self.assertEqual(hook_called[0], 1)
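For context, the scenario this test exercises can be reproduced standalone. A minimal sketch (the Linear module, shapes, and hook body are illustrative assumptions, not taken from the test suite): the input does not require gradients, but the module's parameters do, so backward() still reaches the module and the full backward hook fires with an all-None grad_input, which is exactly what the new warning flags.

    import torch
    import torch.nn as nn

    mod = nn.Linear(4, 4)        # parameters require grad
    inp = torch.randn(2, 4)      # requires_grad=False: no input needs gradients

    def hook(module, grad_input, grad_output):
        # On this path grad_input is a tuple of Nones; grad_output holds the
        # gradients w.r.t. the module outputs.
        print("full backward hook fired")

    mod.register_full_backward_hook(hook)
    mod(inp).sum().backward()    # after this commit, emits the UserWarning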

@@ -223,6 +223,11 @@ class BackwardHook:
        # Special case if no input required gradients, this hook should call the user
        # hook directly
        if self.input_tensors_index is None:
            warnings.warn("Full backward hook is firing when gradients are computed "
                          "with respect to module outputs since no inputs require gradients. See "
                          "https://docs.pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook "  # noqa: B950
                          "for more details.",
                          stacklevel=5)
            grad_inputs = self._pack_with_none([], [], self.n_inputs)
            for user_hook in self.user_hooks:
                res = user_hook(self.module, grad_inputs, self.grad_outputs)
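The warning's docs link points at the underlying fix: if real grad_input values are needed, at least one module input must require gradients, which keeps execution off this special-case branch. A hedged follow-up to the sketch above (same illustrative mod and hook):

    inp = torch.randn(2, 4, requires_grad=True)
    mod(inp).sum().backward()    # normal hook path: real grad_input, no warning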