mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
37 lines
1020 B
Python
37 lines
1020 B
Python
import collections
|
|
import weakref
|
|
|
|
|
|
class RemovableHandle(object):
    """Handle that can later unregister a hook from the dictionary it was stored in.

    Only a weak reference to that dictionary is kept, so holding a handle does
    not keep the dictionary alive. Every handle receives a unique integer ``id``
    from a counter shared by the whole class; :meth:`remove` deletes the entry
    keyed by that ``id``. The handle is also a context manager: leaving a
    ``with`` block calls :meth:`remove`.
    """

    # Class-wide counter: the id the next constructed handle will receive.
    next_id = 0

    def __init__(self, hooks_dict):
        # Created first: if ``hooks_dict`` cannot be weakly referenced this raises
        # before the shared counter is advanced.
        self.hooks_dict_ref = weakref.ref(hooks_dict)
        assigned = RemovableHandle.next_id
        self.id = assigned
        RemovableHandle.next_id = assigned + 1

    def remove(self):
        """Delete this handle's entry from the hooks dictionary.

        Does nothing if the dictionary has been garbage-collected or the entry
        is already gone, so calling it more than once is harmless.
        """
        target = self.hooks_dict_ref()
        if target is None or self.id not in target:
            return
        del target[self.id]

    def __getstate__(self):
        # Pickle a strong reference so the dictionary is serialized along with
        # the handle; a dead weak reference yields ``None`` here.
        live_dict = self.hooks_dict_ref()
        return (live_dict, self.id)

    def __setstate__(self, state):
        referent = state[0]
        if referent is not None:
            self.hooks_dict_ref = weakref.ref(referent)
        else:
            # Nothing live was pickled: reference a throwaway OrderedDict that
            # nothing else holds, so the weak reference is dead once it is
            # collected and ``remove`` becomes a no-op.
            self.hooks_dict_ref = weakref.ref(collections.OrderedDict())
        self.id = state[1]
        # Keep the shared counter ahead of every restored id so handles created
        # afterwards cannot reuse it.
        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        # Unregister when the ``with`` block ends; returning None means any
        # in-flight exception is not suppressed.
        self.remove()