mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Revert "Use weakref.proxy when saving module to internal dictionaries to not increase refcount (#76435)"
This reverts commit 1aa3cbb83b.
Reverted https://github.com/pytorch/pytorch/pull/76435 on behalf of https://github.com/jbschlosser
This commit is contained in:
parent
71d61bb78b
commit
d8b80edade
|
|
@ -10,8 +10,6 @@ import unittest.mock as mock
|
|||
import itertools
|
||||
import warnings
|
||||
import pickle
|
||||
import gc
|
||||
import weakref
|
||||
from copy import deepcopy
|
||||
from itertools import repeat, product
|
||||
from functools import reduce, partial
|
||||
|
|
@ -21521,23 +21519,6 @@ class TestStateDictHooks(TestCase):
|
|||
m_load.load_state_dict(m_state_dict)
|
||||
self.assertEqual(2, hook_called)
|
||||
|
||||
def test_no_extra_ref_to_module(self):
|
||||
try:
|
||||
gc.disable()
|
||||
m = nn.Linear(10, 10)
|
||||
|
||||
def hook_with_module(*args, **kwargs):
|
||||
pass
|
||||
|
||||
m._register_load_state_dict_pre_hook(hook_with_module, True)
|
||||
weak_m = weakref.ref(m)
|
||||
del m
|
||||
|
||||
self.assertEqual(weak_m(), None)
|
||||
finally:
|
||||
gc.enable()
|
||||
|
||||
|
||||
def test_load_state_dict_module_pre_hook(self):
|
||||
hook_called = 0
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from collections import OrderedDict, namedtuple
|
||||
import itertools
|
||||
import warnings
|
||||
import weakref
|
||||
import functools
|
||||
|
||||
import torch
|
||||
|
|
@ -39,20 +38,6 @@ def _addindent(s_, numSpaces):
|
|||
return s
|
||||
|
||||
|
||||
def _wrap_hook(hook, module):
|
||||
weak_module = weakref.ref(module)
|
||||
|
||||
@functools.wraps(hook)
|
||||
def inner(*args, **kwargs):
|
||||
module = weak_module()
|
||||
if module is None:
|
||||
raise RuntimeError("You are trying to call hook of a dead object!")
|
||||
else:
|
||||
return hook(module, *args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
r"""This tracks hooks common to all modules that are executed before/after
|
||||
calling forward and backward. This is global state used for debugging/profiling
|
||||
purposes"""
|
||||
|
|
@ -1181,7 +1166,8 @@ class Module:
|
|||
grad_fn = var.grad_fn
|
||||
if grad_fn is not None:
|
||||
for hook in non_full_backward_hooks:
|
||||
wrapper = _wrap_hook(hook, self)
|
||||
wrapper = functools.partial(hook, self)
|
||||
functools.update_wrapper(wrapper, hook)
|
||||
grad_fn.register_hook(wrapper)
|
||||
self._maybe_warn_non_full_backward_hook(input, result, grad_fn)
|
||||
|
||||
|
|
@ -1416,7 +1402,7 @@ class Module:
|
|||
"""
|
||||
handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks)
|
||||
if with_module:
|
||||
hook = _wrap_hook(hook, self)
|
||||
hook = functools.partial(hook, self)
|
||||
self._load_state_dict_pre_hooks[handle.id] = hook
|
||||
return handle
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user