pytorch/torch/utils/hooks.py
albanD ccd646696b Fix Module backward hooks for all Tensor inputs/outputs (#46163)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/598

This is BC-breaking as we now explicitly don't call the hook when there are no Tensors at the top level of the output.
This feature was not working anyway, as the returned grad_input/grad_output were wrong (they did not respect the output structure, and the inputs were wrong for multi-Node Modules).

This is also BC-breaking as we now report the correct gradients for `nn.Module`s that contain multiple autograd `Node`s, whereas we used to return incorrect results before.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46163

Reviewed By: ailzhang, mruberry

Differential Revision: D24894180

Pulled By: albanD

fbshipit-source-id: e1b5d193d2818eb2f51e2a2722c7405c8bd13c2b
2020-12-18 09:04:36 -08:00

183 lines
6.5 KiB
Python

from __future__ import absolute_import, division, print_function, unicode_literals
import torch
from collections import OrderedDict
import weakref
import warnings
import functools
from typing import Any
class RemovableHandle(object):
    """A handle which provides the capability to remove a hook.

    The handle stores only a weak reference to the dictionary holding the
    hook, so it never keeps that dictionary alive. Every handle gets a unique
    integer ``id``, drawn from the class-wide ``next_id`` counter, and that id
    is the key ``remove()`` deletes. A handle is also a context manager: the
    hook is removed when the ``with`` block exits.
    """

    id: int
    next_id: int = 0

    def __init__(self, hooks_dict: Any) -> None:
        # Weak reference only: the handle must not extend hooks_dict's lifetime.
        self.hooks_dict_ref = weakref.ref(hooks_dict)
        # Take the current counter value as our id, then advance the counter.
        self.id = RemovableHandle.next_id
        RemovableHandle.next_id = self.id + 1

    def remove(self) -> None:
        """Delete this handle's entry; silently does nothing if it is already gone."""
        owner = self.hooks_dict_ref()
        if owner is None:
            # The dictionary has been garbage collected.
            return
        if self.id not in owner:
            # Already removed (or never inserted).
            return
        del owner[self.id]

    def __getstate__(self):
        # Pickle the live dictionary (None if it was collected) alongside the id.
        return (self.hooks_dict_ref(), self.id)

    def __setstate__(self, state) -> None:
        owner = state[0]
        # For a None owner, referencing a brand-new, otherwise-unreferenced
        # OrderedDict produces a weak reference that is already dead.
        self.hooks_dict_ref = weakref.ref(OrderedDict() if owner is None else owner)
        self.id = state[1]
        # Keep future ids above every id restored from a pickle.
        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)

    def __enter__(self) -> 'RemovableHandle':
        return self

    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
        # Always remove the hook; returning None lets any exception propagate.
        self.remove()
def unserializable_hook(f):
    """
    Decorator flagging ``f`` as a hook that is deliberately not serialized.

    It sets ``f.__torch_unserializable__ = True`` and hands back the very same
    function object. Per this decorator's purpose, the flag is meant to silence
    the warning otherwise raised when serializing a tensor that has this hook.
    """
    setattr(f, "__torch_unserializable__", True)
    return f
def warn_if_has_hooks(tensor):
    """
    Warn about each backward hook on ``tensor`` that will not be serialized.

    One ``UserWarning`` is emitted per hook in ``tensor._backward_hooks``,
    except for hooks carrying a ``__torch_unserializable__`` attribute (as set
    by ``@torch.utils.hooks.unserializable_hook``).

    Args:
        tensor: object whose ``_backward_hooks`` attribute is either falsy
            (``None`` / empty) or a mapping whose values are the hook callables.
    """
    if not tensor._backward_hooks:
        return
    for hook in tensor._backward_hooks.values():
        # Inspect the hook itself, not its mapping key: the decorator marks the
        # function object, so checking the key never finds the marker.
        if not hasattr(hook, "__torch_unserializable__"):
            warnings.warn("backward hook {} on tensor will not be "
                          "serialized. If this is expected, you can "
                          "decorate the function with @torch.utils.hooks.unserializable_hook "
                          "to suppress this warning".format(repr(hook)))
class BackwardHook(object):
    """
    A wrapper class to implement nn.Module backward hooks.

    It handles:
    - Ignoring non-Tensor inputs and replacing them by None before calling the user hook
    - Generating the proper Node to capture a set of Tensor's gradients
    - Linking the gradients captured for the outputs with the gradients captured for the inputs
    - Calling the user hook once both output and input gradients are available

    ``setup_input_hook`` and ``setup_output_hook`` return copies of their
    arguments in which the Tensors have been routed through a
    ``BackwardHookFunction`` node. Callers must use the returned values in
    place of the originals; otherwise gradients never reach the registered hooks.
    """

    def __init__(self, module, user_hooks):
        # Each user hook is called as hook(module, grad_input, grad_output).
        self.user_hooks = user_hooks
        self.module = module
        # Set by the output-side node's hook during backward; None until then.
        self.grad_outputs = None
        # Top-level output/input counts (-1 until the matching setup_* call)
        # and the positions of the Tensors among them (None when no Tensor
        # required grad, in which case no hook is installed on that side).
        self.n_outputs = -1
        self.output_tensors_index = None
        self.n_inputs = -1
        self.input_tensors_index = None

    def _pack_with_none(self, indices, values, size):
        # Place values[i] at position indices[i] of a size-long tuple; every
        # other position (a non-Tensor argument) is None.
        res = [None] * size
        for idx, val in zip(indices, values):
            res[idx] = val
        return tuple(res)

    def _unpack_none(self, indices, values):
        # Inverse of _pack_with_none: keep only the entries at ``indices``.
        return tuple(values[idx] for idx in indices)

    def _set_user_hook(self, grad_fn, user_hook):
        # Register ``user_hook`` on the input-side node ``grad_fn``. The user
        # hook receives one gradient per top-level input (None for non-Tensors)
        # and the grad_outputs previously captured by the output-side node.
        @functools.wraps(user_hook)
        def hook(grad_input, _):
            if self.grad_outputs is None:
                raise RuntimeError("Module backward hook for grad_input is called before "
                                   "the grad_output one. This happens because the gradient "
                                   "in your nn.Module flows to the Module's input without "
                                   "passing through the Module's output. Make sure that the "
                                   "output depends on the input and that the loss is computed "
                                   "based on the output.")
            grad_input = self._pack_with_none(self.input_tensors_index, grad_input, self.n_inputs)
            res = user_hook(self.module, grad_input, self.grad_outputs)
            if res is None:
                # The user hook did not replace the gradients.
                return res
            if len(res) != len(grad_input):
                raise RuntimeError("Backward hook returned an invalid number of grad_input, "
                                   "got {}, but expected {}".format(len(res), len(grad_input)))
            # Drop the None placeholders so the node gets one value per Tensor.
            return self._unpack_none(self.input_tensors_index, res)

        grad_fn.register_hook(hook)

    def _apply_on_tensors(self, fn, args):
        """
        Route the Tensors in ``args`` through one BackwardHookFunction node
        and call ``fn`` with that node's grad_fn.

        Returns ``(new_args, tensors_idx)``. ``new_args`` is ``args`` as a tuple,
        with each Tensor replaced by its counterpart output from the node, and
        ``tensors_idx`` lists the positions of the Tensors. If no Tensor in
        ``args`` requires grad, returns ``(args, None)`` unchanged without
        calling ``fn``.

        Raises:
            RuntimeError: if no suitable ``BackwardHookFunctionBackward`` node
                can be found on the routed Tensors.
        """
        tensors_idx = []
        tensors = []
        requires_grad = False
        for i, arg in enumerate(args):
            if isinstance(arg, torch.Tensor):
                tensors_idx.append(i)
                tensors.append(arg)
                requires_grad |= arg.requires_grad

        if not requires_grad:
            return args, None

        new_tensors = torch.nn.modules._functions.BackwardHookFunction.apply(*tensors)
        if len(new_tensors) == 0:
            raise RuntimeError("Cannot set Module backward hook for a Module with no input Tensors.")

        # An output whose input did not require grad can have no grad_fn (it
        # may be non-differentiable). So new_tensors[0] is not a safe choice;
        # take the node from the first output that actually has it.
        grad_fn = next(
            (t.grad_fn for t in new_tensors
             if t.grad_fn is not None and t.grad_fn.name() == "BackwardHookFunctionBackward"),
            None,
        )
        if grad_fn is None:
            raise RuntimeError("Error while setting up backward hooks. Please open "
                               "an issue with a code sample to reproduce this.")

        fn(grad_fn)

        arg_list = list(args)
        for idx, val in zip(tensors_idx, new_tensors):
            arg_list[idx] = val
        return tuple(arg_list), tensors_idx

    def setup_input_hook(self, args):
        """
        Wrap the module's top-level inputs ``args`` (a sequence) and register
        every user hook on the resulting node. Returns the args to use instead.
        """
        def fn(grad_fn):
            for hook in self.user_hooks:
                self._set_user_hook(grad_fn, hook)

        res, input_idx = self._apply_on_tensors(fn, args)
        self.n_inputs = len(args)
        self.input_tensors_index = input_idx
        return res

    def setup_output_hook(self, args):
        """
        Wrap the module's output ``args`` (a tuple, or a single value treated as
        a 1-tuple) so its gradients are stored in ``self.grad_outputs`` during
        backward. Returns the output to use instead, tuple-ness preserved.
        """
        def fn(grad_fn):
            def hook(_, grad_output):
                self.grad_outputs = self._pack_with_none(self.output_tensors_index,
                                                         grad_output,
                                                         self.n_outputs)
            grad_fn.register_hook(hook)

        is_tuple = True
        if not isinstance(args, tuple):
            args = (args,)
            is_tuple = False

        res, output_idx = self._apply_on_tensors(fn, args)
        self.n_outputs = len(args)
        self.output_tensors_index = output_idx
        if not is_tuple:
            res = res[0]
        return res