Revert "Use weakref.proxy when saving module to internal dictionaries to not increase refcount (#76435)"

This reverts commit 1aa3cbb83b.

Reverted https://github.com/pytorch/pytorch/pull/76435 on behalf of https://github.com/jbschlosser
This commit is contained in:
PyTorch MergeBot 2022-05-17 17:51:26 +00:00
parent 71d61bb78b
commit d8b80edade
2 changed files with 3 additions and 36 deletions

View File

@ -10,8 +10,6 @@ import unittest.mock as mock
import itertools
import warnings
import pickle
import gc
import weakref
from copy import deepcopy
from itertools import repeat, product
from functools import reduce, partial
@ -21521,23 +21519,6 @@ class TestStateDictHooks(TestCase):
m_load.load_state_dict(m_state_dict)
self.assertEqual(2, hook_called)
def test_no_extra_ref_to_module(self):
    # Registering a load_state_dict pre-hook with with_module=True must not
    # hold a strong reference to the module: with the cyclic GC disabled,
    # dropping the last user reference should be enough to free it, so the
    # weak reference must resolve to None afterwards.
    try:
        gc.disable()
        module = nn.Linear(10, 10)

        def hook_with_module(*args, **kwargs):
            pass

        module._register_load_state_dict_pre_hook(hook_with_module, True)
        module_ref = weakref.ref(module)
        del module
        self.assertEqual(module_ref(), None)
    finally:
        gc.enable()
def test_load_state_dict_module_pre_hook(self):
    hook_called = 0

View File

@ -1,7 +1,6 @@
from collections import OrderedDict, namedtuple
import itertools
import warnings
import weakref
import functools
import torch
@ -39,20 +38,6 @@ def _addindent(s_, numSpaces):
return s
def _wrap_hook(hook, module):
weak_module = weakref.ref(module)
@functools.wraps(hook)
def inner(*args, **kwargs):
module = weak_module()
if module is None:
raise RuntimeError("You are trying to call hook of a dead object!")
else:
return hook(module, *args, **kwargs)
return inner
r"""This tracks hooks common to all modules that are executed before/after
calling forward and backward. This is global state used for debugging/profiling
purposes"""
@ -1181,7 +1166,8 @@ class Module:
grad_fn = var.grad_fn
if grad_fn is not None:
    for hook in non_full_backward_hooks:
wrapper = _wrap_hook(hook, self) wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
self._maybe_warn_non_full_backward_hook(input, result, grad_fn)
@ -1416,7 +1402,7 @@ class Module:
"""
handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks)
if with_module:
hook = _wrap_hook(hook, self) hook = functools.partial(hook, self)
self._load_state_dict_pre_hooks[handle.id] = hook
return handle