mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Add warning for module full backward hook when no input requires gradient (#155339)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155339 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
e25ce0f928
commit
671a9d175b
|
|
@ -1445,6 +1445,13 @@ class TestModuleHookNN(NNTestCase):
|
||||||
mod.register_full_backward_hook(hook)
|
mod.register_full_backward_hook(hook)
|
||||||
|
|
||||||
# This should run and trigger the hook properly
|
# This should run and trigger the hook properly
|
||||||
|
with self.assertWarnsRegex(
|
||||||
|
UserWarning,
|
||||||
|
(
|
||||||
|
"Full backward hook is firing when gradients are computed with "
|
||||||
|
"respect to module outputs since no inputs require gradients"
|
||||||
|
),
|
||||||
|
):
|
||||||
mod(inp).sum().backward()
|
mod(inp).sum().backward()
|
||||||
self.assertEqual(hook_called[0], 1)
|
self.assertEqual(hook_called[0], 1)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -223,6 +223,11 @@ class BackwardHook:
|
||||||
# Special case if no input required gradients, this hook should call the user
|
# Special case if no input required gradients, this hook should call the user
|
||||||
# hook directly
|
# hook directly
|
||||||
if self.input_tensors_index is None:
|
if self.input_tensors_index is None:
|
||||||
|
warnings.warn("Full backward hook is firing when gradients are computed "
|
||||||
|
"with respect to module outputs since no inputs require gradients. See "
|
||||||
|
"https://docs.pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook " # noqa: B950
|
||||||
|
"for more details.",
|
||||||
|
stacklevel=5)
|
||||||
grad_inputs = self._pack_with_none([], [], self.n_inputs)
|
grad_inputs = self._pack_with_none([], [], self.n_inputs)
|
||||||
for user_hook in self.user_hooks:
|
for user_hook in self.user_hooks:
|
||||||
res = user_hook(self.module, grad_inputs, self.grad_outputs)
|
res = user_hook(self.module, grad_inputs, self.grad_outputs)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user