Revert "Use weakref.proxy when saving module to internal dictionaries to not increase refcount (#76435)"

This reverts commit 1aa3cbb83b.

Reverted https://github.com/pytorch/pytorch/pull/76435 on behalf of https://github.com/jbschlosser
This commit is contained in:
PyTorch MergeBot 2022-05-17 17:51:26 +00:00
parent 71d61bb78b
commit d8b80edade
2 changed files with 3 additions and 36 deletions

View File

@ -10,8 +10,6 @@ import unittest.mock as mock
import itertools
import warnings
import pickle
import gc
import weakref
from copy import deepcopy
from itertools import repeat, product
from functools import reduce, partial
@ -21521,23 +21519,6 @@ class TestStateDictHooks(TestCase):
m_load.load_state_dict(m_state_dict)
self.assertEqual(2, hook_called)
def test_no_extra_ref_to_module(self):
try:
gc.disable()
m = nn.Linear(10, 10)
def hook_with_module(*args, **kwargs):
pass
m._register_load_state_dict_pre_hook(hook_with_module, True)
weak_m = weakref.ref(m)
del m
self.assertEqual(weak_m(), None)
finally:
gc.enable()
def test_load_state_dict_module_pre_hook(self):
hook_called = 0

View File

@ -1,7 +1,6 @@
from collections import OrderedDict, namedtuple
import itertools
import warnings
import weakref
import functools
import torch
@ -39,20 +38,6 @@ def _addindent(s_, numSpaces):
return s
def _wrap_hook(hook, module):
weak_module = weakref.ref(module)
@functools.wraps(hook)
def inner(*args, **kwargs):
module = weak_module()
if module is None:
raise RuntimeError("You are trying to call hook of a dead object!")
else:
return hook(module, *args, **kwargs)
return inner
r"""This tracks hooks common to all modules that are executed before/after
calling forward and backward. This is global state used for debugging/profiling
purposes"""
@ -1181,7 +1166,8 @@ class Module:
grad_fn = var.grad_fn
if grad_fn is not None:
for hook in non_full_backward_hooks:
wrapper = _wrap_hook(hook, self)
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
self._maybe_warn_non_full_backward_hook(input, result, grad_fn)
@ -1416,7 +1402,7 @@ class Module:
"""
handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks)
if with_module:
hook = _wrap_hook(hook, self)
hook = functools.partial(hook, self)
self._load_state_dict_pre_hooks[handle.id] = hook
return handle