Summary:
Fixes https://github.com/pytorch/pytorch/issues/598

This is BC-breaking as we now explicitly don't call the hook when there are no Tensors at the top level of the output. This feature was not working anyway, as the returned grad_input/grad_output were wrong (they did not respect the output structure and gave wrong inputs for multi-Node Modules). It is also BC-breaking in that we now report the correct gradients for `nn.Module`s that contain multiple autograd `Node`s, where we used to return bad results.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46163

Reviewed By: ailzhang, mruberry

Differential Revision: D24894180

Pulled By: albanD

fbshipit-source-id: e1b5d193d2818eb2f51e2a2722c7405c8bd13c2b
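To make the behavioral change concrete, here is a minimal sketch (illustrative only, not part of the commit; it assumes the hook is registered through nn.Module.register_backward_hook, and the exact registration entry point may differ across versions). The user hook receives grad_input/grad_output tuples that mirror the Module's input/output structure, with None in the non-Tensor positions, and it is not called at all when the output contains no Tensors at the top level:

    import torch
    import torch.nn as nn

    class TwoOpModule(nn.Module):
        def forward(self, x):
            # two autograd Nodes inside a single Module
            return (x * 2).relu()

    def hook(module, grad_input, grad_output):
        # tuples aligned with the Module's inputs/outputs; non-Tensor slots are None
        print(grad_input, grad_output)

    m = TwoOpModule()
    m.register_backward_hook(hook)
    m(torch.ones(2, requires_grad=True)).sum().backward()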
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
from collections import OrderedDict
import weakref
import warnings
import functools
from typing import Any


class RemovableHandle(object):
    """A handle which provides the capability to remove a hook."""

    id: int
    next_id: int = 0

    def __init__(self, hooks_dict: Any) -> None:
        self.hooks_dict_ref = weakref.ref(hooks_dict)
        self.id = RemovableHandle.next_id
        RemovableHandle.next_id += 1

    def remove(self) -> None:
        hooks_dict = self.hooks_dict_ref()
        if hooks_dict is not None and self.id in hooks_dict:
            del hooks_dict[self.id]

    def __getstate__(self):
        return (self.hooks_dict_ref(), self.id)

    def __setstate__(self, state) -> None:
        if state[0] is None:
            # create a dead reference
            self.hooks_dict_ref = weakref.ref(OrderedDict())
        else:
            self.hooks_dict_ref = weakref.ref(state[0])
        self.id = state[1]
        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)

    def __enter__(self) -> 'RemovableHandle':
        return self

    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
        self.remove()
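
# Example (illustrative sketch, not part of the original file): a RemovableHandle
# is what hook-registration methods such as Tensor.register_hook return, so a hook
# can be detached later, or scoped with a ``with`` block via __enter__/__exit__:
#
#     t = torch.ones(3, requires_grad=True)
#     handle = t.register_hook(lambda grad: grad * 2)   # returns a RemovableHandle
#     t.sum().backward()
#     handle.remove()                                    # the hook no longer fires
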

def unserializable_hook(f):
    """
    Decorator which marks a function as an unserializable hook.
    This suppresses warnings that would otherwise arise if you attempt
    to serialize a tensor that has a hook.
    """
    f.__torch_unserializable__ = True
    return f
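
# Example (illustrative, not part of the original file): a hook that is known not
# to survive serialization can be marked with the decorator above so that
# warn_if_has_hooks (below) stays silent about it:
#
#     @torch.utils.hooks.unserializable_hook
#     def scale_grad(grad):
#         return grad * 0.5
#
#     t = torch.ones(2, requires_grad=True)
#     t.register_hook(scale_grad)
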

def warn_if_has_hooks(tensor):
    if tensor._backward_hooks:
        for k in tensor._backward_hooks:
            hook = tensor._backward_hooks[k]
            if not hasattr(hook, "__torch_unserializable__"):
                warnings.warn("backward hook {} on tensor will not be "
                              "serialized. If this is expected, you can "
                              "decorate the function with @torch.utils.hooks.unserializable_hook "
                              "to suppress this warning".format(repr(hook)))


class BackwardHook(object):
    """
    A wrapper class to implement nn.Module backward hooks.
    It handles:
      - Ignoring non-Tensor inputs and replacing them by None before calling the user hook
      - Generating the proper Node to capture a set of Tensors' gradients
      - Linking the gradients captured for the outputs with the gradients captured for the inputs
      - Calling the user hook once both output and input gradients are available
    """

    def __init__(self, module, user_hooks):
        self.user_hooks = user_hooks
        self.module = module

        self.grad_outputs = None
        self.n_outputs = -1
        self.output_tensors_index = None
        self.n_inputs = -1
        self.input_tensors_index = None

    def _pack_with_none(self, indices, values, size):
        res = [None] * size
        for idx, val in zip(indices, values):
            res[idx] = val

        return tuple(res)

    def _unpack_none(self, indices, values):
        res = []
        for idx in indices:
            res.append(values[idx])

        return tuple(res)
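
    # Illustrative note (not part of the original file): with indices = (0, 2) and
    # size = 3, _pack_with_none((0, 2), (gx, gz), 3) gives (gx, None, gz), i.e. the
    # captured gradients are slotted back at the positions of the Tensor arguments,
    # while _unpack_none((0, 2), (gx, None, gz)) gives (gx, gz), dropping the
    # placeholder slots again before the values are handed back to autograd.
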
    def _set_user_hook(self, grad_fn, user_hook):
        @functools.wraps(user_hook)
        def hook(grad_input, _):
            if self.grad_outputs is None:
                raise RuntimeError("Module backward hook for grad_input is called before "
                                   "the grad_output one. This happens because the gradient "
                                   "in your nn.Module flows to the Module's input without "
                                   "passing through the Module's output. Make sure that the "
                                   "output depends on the input and that the loss is computed "
                                   "based on the output.")

            grad_input = self._pack_with_none(self.input_tensors_index, grad_input, self.n_inputs)
            res = user_hook(self.module, grad_input, self.grad_outputs)
            if res is None:
                return res

            if len(res) != len(grad_input):
                raise RuntimeError("Backward hook returned an invalid number of grad_input, "
                                   "got {}, but expected {}".format(len(res), len(grad_input)))
            return self._unpack_none(self.input_tensors_index, res)
        grad_fn.register_hook(hook)

    def _apply_on_tensors(self, fn, args):
        # Can be used to apply the given function to the Tensors contained in
        # args. Will return the updated args and the Tensor indices.
        tensors_idx = []
        tensors = []

        requires_grad = False
        for i, arg in enumerate(args):
            if isinstance(arg, torch.Tensor):
                tensors_idx.append(i)
                tensors.append(arg)
                requires_grad |= arg.requires_grad

        if not requires_grad:
            return args, None

        new_tensors = torch.nn.modules._functions.BackwardHookFunction.apply(*tensors)
        if len(new_tensors) == 0:
            raise RuntimeError("Cannot set Module backward hook for a Module with no input Tensors.")
        grad_fn = new_tensors[0].grad_fn
        if grad_fn.name() != "BackwardHookFunctionBackward":
            raise RuntimeError("Error while setting up backward hooks. Please open "
                               "an issue with a code sample to reproduce this.")

        fn(grad_fn)

        arg_list = list(args)
        for idx, val in zip(tensors_idx, new_tensors):
            arg_list[idx] = val

        return tuple(arg_list), tensors_idx

    def setup_input_hook(self, args):
        def fn(grad_fn):
            for hook in self.user_hooks:
                self._set_user_hook(grad_fn, hook)

        res, input_idx = self._apply_on_tensors(fn, args)
        self.n_inputs = len(args)
        self.input_tensors_index = input_idx
        return res

    def setup_output_hook(self, args):
        def fn(grad_fn):
            def hook(_, grad_output):
                self.grad_outputs = self._pack_with_none(self.output_tensors_index,
                                                         grad_output,
                                                         self.n_outputs)
            grad_fn.register_hook(hook)

        is_tuple = True
        if not isinstance(args, tuple):
            args = (args,)
            is_tuple = False

        res, output_idx = self._apply_on_tensors(fn, args)
        self.n_outputs = len(args)
        self.output_tensors_index = output_idx

        if not is_tuple:
            res = res[0]
        return res
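
# Illustrative sketch (not part of the original file): roughly how a Module's call
# path is expected to drive BackwardHook. The real wiring lives in
# torch/nn/modules/module.py; the variable names below are only for illustration.
#
#     bw_hook = BackwardHook(module, user_hooks)
#     args = bw_hook.setup_input_hook(args)        # wrap inputs so grad_input is captured
#     output = module.forward(*args)
#     output = bw_hook.setup_output_hook(output)   # wrap outputs so grad_output is captured
#
# During backward(), the output-side hook fires first and stores grad_outputs; the
# input-side hook then calls each user hook as hook(module, grad_input, grad_outputs).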