Add warning for module full backward hook when no input requires gradient (#155339)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155339
Approved by: https://github.com/Skylion007
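For context, a minimal sketch of the situation this commit targets (illustrative names; assumes a PyTorch build that includes this change):

import torch
import torch.nn as nn

def hook(module, grad_input, grad_output):
    # With no inputs requiring grad, every entry of grad_input is None;
    # only gradients w.r.t. the module outputs are available.
    print("grad_input:", grad_input)

inp = torch.rand(2)  # requires_grad=False by default
mod = nn.Linear(2, 5)
mod.register_full_backward_hook(hook)

# With this change, backward() now emits a UserWarning noting that the
# full backward hook only fires for gradients w.r.t. the outputs.
mod(inp).sum().backward()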
parent e25ce0f928
commit 671a9d175b
@@ -1445,7 +1445,14 @@ class TestModuleHookNN(NNTestCase):
         mod.register_full_backward_hook(hook)
 
         # This should run and trigger the hook properly
-        mod(inp).sum().backward()
+        with self.assertWarnsRegex(
+            UserWarning,
+            (
+                "Full backward hook is firing when gradients are computed with "
+                "respect to module outputs since no inputs require gradients"
+            ),
+        ):
+            mod(inp).sum().backward()
         self.assertEqual(hook_called[0], 1)
 
         return_val = "grad_input"
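Outside of unittest, the same check can be reproduced with the plain warnings module. A sketch, again assuming the patched build (names are illustrative):

import warnings

import torch
import torch.nn as nn

mod = nn.Linear(2, 5)
mod.register_full_backward_hook(lambda m, gi, go: None)
inp = torch.rand(2)  # no input requires grad, so the warning should fire

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    mod(inp).sum().backward()

assert any("no inputs require gradients" in str(w.message) for w in caught)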
@@ -223,6 +223,11 @@ class BackwardHook:
             # Special case if no input required gradients, this hook should call the user
             # hook directly
             if self.input_tensors_index is None:
+                warnings.warn("Full backward hook is firing when gradients are computed "
+                              "with respect to module outputs since no inputs require gradients. See "
+                              "https://docs.pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook "  # noqa: B950
+                              "for more details.",
+                              stacklevel=5)
                 grad_inputs = self._pack_with_none([], [], self.n_inputs)
                 for user_hook in self.user_hooks:
                     res = user_hook(self.module, grad_inputs, self.grad_outputs)
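The warning points at the fix rather than just the symptom: once at least one input requires gradients, input_tensors_index is no longer None, so this special case (and the warning) is skipped entirely. A sketch of the warning-free path, under the same assumptions as above:

import torch
import torch.nn as nn

def hook(module, grad_input, grad_output):
    print("grad_input:", grad_input)  # grad_input now holds a real tensor

mod = nn.Linear(2, 5)
mod.register_full_backward_hook(hook)

# Marking the input as requiring grad means input_tensors_index is set,
# so the special case above is skipped and no warning is raised.
inp = torch.rand(2, requires_grad=True)
mod(inp).sum().backward()

If the output-only behavior is intentional, the warning can also be silenced with the standard warnings.filterwarnings machinery.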