from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import weakref
import warnings


class RemovableHandle(object):
    """A handle which provides the capability to remove a hook.

    The handle keeps only a weak reference to the dictionary that stores the
    hook, together with the id under which the hook is (expected to be)
    stored. It can also be used as a context manager, in which case the hook
    is removed when the ``with`` block exits.
    """

    # Class-wide counter used to hand out a distinct id to every new handle.
    next_id = 0

    def __init__(self, hooks_dict):
        """Create a handle bound to ``hooks_dict``.

        Args:
            hooks_dict: the mapping in which the hook is stored, keyed by the
                id of this handle. Only a weak reference to it is kept.
        """
        self.hooks_dict_ref = weakref.ref(hooks_dict)
        self.id = RemovableHandle.next_id
        RemovableHandle.next_id += 1

    def remove(self):
        """Delete this handle's entry from the hooks dictionary.

        Does nothing if the dictionary no longer exists or if it holds no
        entry for this handle's id, so calling it repeatedly is safe.
        """
        hooks_dict = self.hooks_dict_ref()
        if hooks_dict is not None and self.id in hooks_dict:
            del hooks_dict[self.id]

    def __getstate__(self):
        """Return ``(hooks_dict_or_None, id)`` as the picklable state."""
        # Dereference the weak reference: the state carries the dictionary
        # itself, or None if it has already been collected.
        return (self.hooks_dict_ref(), self.id)

    def __setstate__(self, state):
        """Restore the handle from a ``(hooks_dict_or_None, id)`` state."""
        if state[0] is None:
            # create a dead reference
            self.hooks_dict_ref = weakref.ref(collections.OrderedDict())
        else:
            self.hooks_dict_ref = weakref.ref(state[0])
        self.id = state[1]
        # Keep ids handed out from now on above the restored one.
        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        # Remove the hook on exit; returns None, so exceptions propagate.
        self.remove()


def unserializable_hook(f):
    """Decorator which marks a function as an unserializable hook.

    This suppresses warnings that would otherwise arise if you attempt
    to serialize a tensor that has a hook.

    Args:
        f: the hook function to mark.

    Returns:
        The same function object, with ``__torch_unserializable__`` set to
        ``True``.
    """
    f.__torch_unserializable__ = True
    return f


def warn_if_has_hooks(tensor):
    """Warn about every backward hook on ``tensor`` that will not be serialized.

    One warning is issued per hook stored in ``tensor._backward_hooks``,
    except for hooks carrying the ``__torch_unserializable__`` marker set by
    :func:`unserializable_hook`. Nothing happens when ``_backward_hooks`` is
    ``None`` or empty.

    Args:
        tensor: an object exposing a ``_backward_hooks`` attribute that is
            either falsy or a mapping from handle id to hook.
    """
    if not tensor._backward_hooks:
        return
    for hook in tensor._backward_hooks.values():
        # The marker lives on the hook function, not on its dictionary key
        # (the handle id), so it is the hook that must be inspected.
        if not hasattr(hook, "__torch_unserializable__"):
            warnings.warn("backward hook {} on tensor will not be "
                          "serialized. If this is expected, you can "
                          "decorate the function with @torch.utils.hooks.unserializable_hook "
                          "to suppress this warning".format(repr(hook)))