From 671a9d175b7e15b4947c97f50b931dba1cc35a63 Mon Sep 17 00:00:00 2001
From: Mikayla Gawarecki
Date: Fri, 6 Jun 2025 13:48:39 -0700
Subject: [PATCH] Add warning for module full backward hook when no input
 requires gradient (#155339)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155339
Approved by: https://github.com/Skylion007
---
 test/nn/test_module_hooks.py | 9 ++++++++-
 torch/utils/hooks.py         | 5 +++++
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/test/nn/test_module_hooks.py b/test/nn/test_module_hooks.py
index c9c29f0ba4a..72e3665cfdd 100644
--- a/test/nn/test_module_hooks.py
+++ b/test/nn/test_module_hooks.py
@@ -1445,7 +1445,14 @@ class TestModuleHookNN(NNTestCase):
         mod.register_full_backward_hook(hook)
 
         # This should run and trigger the hook properly
-        mod(inp).sum().backward()
+        with self.assertWarnsRegex(
+            UserWarning,
+            (
+                "Full backward hook is firing when gradients are computed with "
+                "respect to module outputs since no inputs require gradients"
+            ),
+        ):
+            mod(inp).sum().backward()
         self.assertEqual(hook_called[0], 1)
 
         return_val = "grad_input"
diff --git a/torch/utils/hooks.py b/torch/utils/hooks.py
index c41add0fcbe..e6e93966afd 100644
--- a/torch/utils/hooks.py
+++ b/torch/utils/hooks.py
@@ -223,6 +223,11 @@ class BackwardHook:
             # Special case if no input required gradients, this hook should call the user
             # hook directly
             if self.input_tensors_index is None:
+                warnings.warn("Full backward hook is firing when gradients are computed "
+                              "with respect to module outputs since no inputs require gradients. See "
+                              "https://docs.pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook "  # noqa: B950
+                              "for more details.",
+                              stacklevel=5)
                 grad_inputs = self._pack_with_none([], [], self.n_inputs)
                 for user_hook in self.user_hooks:
                     res = user_hook(self.module, grad_inputs, self.grad_outputs)