mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43406 Reviewed By: mruberry Differential Revision: D23319736 Pulled By: malfet fbshipit-source-id: e25fbb49f27aa4893590b022441303d6d98263a9
62 lines
2.0 KiB
Python
62 lines
2.0 KiB
Python
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
from collections import OrderedDict
|
|
import weakref
|
|
import warnings
|
|
from typing import Any
|
|
|
|
|
|
class RemovableHandle(object):
    """Token for a registered hook; removing it unregisters that hook.

    The handle keeps only a weak reference to the hooks dictionary it was
    created for, and locates its own entry there via ``self.id``. It can
    also be used as a context manager: leaving the ``with`` block calls
    :meth:`remove`.
    """

    # Key of this handle's entry in the hooks dictionary.
    id: int
    # Class-level counter holding the id the next new handle will receive.
    next_id: int = 0

    def __init__(self, hooks_dict: Any) -> None:
        self.hooks_dict_ref = weakref.ref(hooks_dict)
        # Take the current counter value as this handle's id, then advance it.
        assigned = RemovableHandle.next_id
        self.id = assigned
        RemovableHandle.next_id = assigned + 1

    def remove(self) -> None:
        """Delete this handle's entry from the hooks dict, if both still exist."""
        target = self.hooks_dict_ref()
        if target is None:
            # The weakly-referenced hooks dict no longer exists.
            return
        # Calling remove() more than once is a harmless no-op.
        if self.id in target:
            del target[self.id]

    def __getstate__(self):
        """Pickle the referenced hooks dict (``None`` if it is gone) plus the id."""
        return self.hooks_dict_ref(), self.id

    def __setstate__(self, state) -> None:
        """Restore from the ``(hooks_dict_or_None, id)`` pair built by ``__getstate__``."""
        saved_dict = state[0]
        # With no dict to point at, wrap a throwaway OrderedDict so the
        # restored reference is a dead one, matching the pickled state.
        self.hooks_dict_ref = weakref.ref(
            OrderedDict() if saved_dict is None else saved_dict
        )
        self.id = state[1]
        # Make sure ids issued after unpickling don't collide with this one.
        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)

    def __enter__(self) -> 'RemovableHandle':
        return self

    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
        # Unregister on exit; returning None means exceptions propagate.
        self.remove()
|
|
|
|
|
|
def unserializable_hook(f):
    """Decorator that flags ``f`` as a hook which is not meant to be serialized.

    It sets the ``__torch_unserializable__`` attribute on ``f`` to ``True``
    and hands back the very same object. The marker exists to suppress the
    warning that would otherwise be raised when serializing a tensor that
    has this hook attached.
    """
    setattr(f, "__torch_unserializable__", True)
    return f
|
|
|
|
|
|
def warn_if_has_hooks(tensor):
    """Warn once per backward hook on ``tensor`` that will not be serialized.

    Hooks decorated with ``@torch.utils.hooks.unserializable_hook`` carry a
    ``__torch_unserializable__`` attribute and are skipped silently.

    Args:
        tensor: object exposing a ``_backward_hooks`` mapping (or a falsy
            value such as ``None``/empty when it has no hooks).
    """
    if tensor._backward_hooks:
        # The mapping's keys are only identifiers; the marker attribute lives
        # on the hook callable itself, so inspect the values.
        for hook in tensor._backward_hooks.values():
            if not hasattr(hook, "__torch_unserializable__"):
                warnings.warn("backward hook {} on tensor will not be "
                              "serialized. If this is expected, you can "
                              "decorate the function with @torch.utils.hooks.unserializable_hook "
                              "to suppress this warning".format(repr(hook)))
|