mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Revert "Use weakref.proxy when saving module to internal dictionaries to not increase refcount (#76435)"
This reverts commit 1aa3cbb83b.
Reverted https://github.com/pytorch/pytorch/pull/76435 on behalf of https://github.com/jbschlosser
This commit is contained in:
parent
71d61bb78b
commit
d8b80edade
|
|
@ -10,8 +10,6 @@ import unittest.mock as mock
|
||||||
import itertools
|
import itertools
|
||||||
import warnings
|
import warnings
|
||||||
import pickle
|
import pickle
|
||||||
import gc
|
|
||||||
import weakref
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from itertools import repeat, product
|
from itertools import repeat, product
|
||||||
from functools import reduce, partial
|
from functools import reduce, partial
|
||||||
|
|
@ -21521,23 +21519,6 @@ class TestStateDictHooks(TestCase):
|
||||||
m_load.load_state_dict(m_state_dict)
|
m_load.load_state_dict(m_state_dict)
|
||||||
self.assertEqual(2, hook_called)
|
self.assertEqual(2, hook_called)
|
||||||
|
|
||||||
def test_no_extra_ref_to_module(self):
|
|
||||||
try:
|
|
||||||
gc.disable()
|
|
||||||
m = nn.Linear(10, 10)
|
|
||||||
|
|
||||||
def hook_with_module(*args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
m._register_load_state_dict_pre_hook(hook_with_module, True)
|
|
||||||
weak_m = weakref.ref(m)
|
|
||||||
del m
|
|
||||||
|
|
||||||
self.assertEqual(weak_m(), None)
|
|
||||||
finally:
|
|
||||||
gc.enable()
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_state_dict_module_pre_hook(self):
|
def test_load_state_dict_module_pre_hook(self):
|
||||||
hook_called = 0
|
hook_called = 0
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
from collections import OrderedDict, namedtuple
|
from collections import OrderedDict, namedtuple
|
||||||
import itertools
|
import itertools
|
||||||
import warnings
|
import warnings
|
||||||
import weakref
|
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -39,20 +38,6 @@ def _addindent(s_, numSpaces):
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
def _wrap_hook(hook, module):
|
|
||||||
weak_module = weakref.ref(module)
|
|
||||||
|
|
||||||
@functools.wraps(hook)
|
|
||||||
def inner(*args, **kwargs):
|
|
||||||
module = weak_module()
|
|
||||||
if module is None:
|
|
||||||
raise RuntimeError("You are trying to call hook of a dead object!")
|
|
||||||
else:
|
|
||||||
return hook(module, *args, **kwargs)
|
|
||||||
|
|
||||||
return inner
|
|
||||||
|
|
||||||
|
|
||||||
r"""This tracks hooks common to all modules that are executed before/after
|
r"""This tracks hooks common to all modules that are executed before/after
|
||||||
calling forward and backward. This is global state used for debugging/profiling
|
calling forward and backward. This is global state used for debugging/profiling
|
||||||
purposes"""
|
purposes"""
|
||||||
|
|
@ -1181,7 +1166,8 @@ class Module:
|
||||||
grad_fn = var.grad_fn
|
grad_fn = var.grad_fn
|
||||||
if grad_fn is not None:
|
if grad_fn is not None:
|
||||||
for hook in non_full_backward_hooks:
|
for hook in non_full_backward_hooks:
|
||||||
wrapper = _wrap_hook(hook, self)
|
wrapper = functools.partial(hook, self)
|
||||||
|
functools.update_wrapper(wrapper, hook)
|
||||||
grad_fn.register_hook(wrapper)
|
grad_fn.register_hook(wrapper)
|
||||||
self._maybe_warn_non_full_backward_hook(input, result, grad_fn)
|
self._maybe_warn_non_full_backward_hook(input, result, grad_fn)
|
||||||
|
|
||||||
|
|
@ -1416,7 +1402,7 @@ class Module:
|
||||||
"""
|
"""
|
||||||
handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks)
|
handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks)
|
||||||
if with_module:
|
if with_module:
|
||||||
hook = _wrap_hook(hook, self)
|
hook = functools.partial(hook, self)
|
||||||
self._load_state_dict_pre_hooks[handle.id] = hook
|
self._load_state_dict_pre_hooks[handle.id] = hook
|
||||||
return handle
|
return handle
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user