[easy] Move state_dict hooks tests to test_module_hooks and decorate tests that call load_state_dict with swap (#126906)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126906
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki 2024-06-10 11:12:32 -07:00 committed by PyTorch MergeBot
parent 58083ffb10
commit a2d4fea872
3 changed files with 89 additions and 81 deletions

View File

@ -3,14 +3,12 @@ import re
import unittest
from copy import deepcopy
from itertools import product
from tempfile import NamedTemporaryFile
import torch
import torch.nn as nn
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_WINDOWS,
parametrize,
run_tests,
skipIfCrossRef,
@ -206,33 +204,6 @@ class TestLoadStateDict(NNTestCase):
model[0][0]._register_load_state_dict_pre_hook(hook_fn, with_module=True)
model.load_state_dict(model.state_dict(), strict=True)
@unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows")
@swap([True, False])
def test_register_state_dict_pre_hook_backward_compat(self):
called = False
def my_state_dict_pre_hook(*args, **kwargs):
nonlocal called
called = True
m = nn.Linear(1, 1)
self.assertTrue(hasattr(m, "_state_dict_pre_hooks"))
delattr(m, "_state_dict_pre_hooks")
# Save and load, ensure we can still call state_dict
# without running into issues.
with NamedTemporaryFile() as f:
# Note that torch.save / torch.load is not recommended
# to save / load modules.
torch.save(m, f.name)
m = torch.load(f.name)
# Ensure we can run state_dict without issues
_ = m.state_dict()
self.assertFalse(called)
m.register_state_dict_pre_hook(my_state_dict_pre_hook)
_ = m.state_dict()
self.assertTrue(called)
# fails swapping as LSTM installs weak references on the parameters
@swap([False])
@skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")

View File

@ -21,6 +21,7 @@ from torch.testing._internal.common_utils import (
parametrize as parametrize_test,
run_tests,
skipIfTorchDynamo,
swap,
TestCase,
)
@ -549,6 +550,7 @@ def _hook_to_pickle(*args, **kwargs):
class TestStateDictHooks(TestCase):
@swap([True, False])
def test_load_state_dict_pre_hook(self):
m = nn.Linear(10, 10)
m_state_dict = m.state_dict()
@ -613,6 +615,7 @@ class TestStateDictHooks(TestCase):
m._register_load_state_dict_pre_hook(_hook_to_pickle, True)
pickle.loads(pickle.dumps(m))
@swap([True, False])
def test_load_state_dict_module_pre_hook(self):
hook_called = 0
@ -686,6 +689,7 @@ class TestStateDictHooks(TestCase):
m.load_state_dict(state_dict)
self.assertEqual(2, hook_called)
@swap([True, False])
def test_load_state_dict_post_hook(self):
hook_called = 0
@ -743,6 +747,7 @@ class TestStateDictHooks(TestCase):
self.assertEqual([], ret.unexpected_keys)
@unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows")
@swap([True, False])
def test_load_state_dict_post_hook_backward_compatibility(self):
def my_post_load_hook(mod, _):
nonlocal called
@ -771,6 +776,89 @@ class TestStateDictHooks(TestCase):
m.load_state_dict(sd)
self.assertTrue(called)
def _test_register_state_dict_pre_hook(self, model, submodule):
_state_dict_prefix = "foo."
state_dict_pre_hook_count = 0
keep_var_setting = False
def my_state_dict_pre_hook(module, prefix, keep_vars):
self.assertEqual(keep_vars, keep_var_setting)
nonlocal state_dict_pre_hook_count
state_dict_pre_hook_count += 1
self.assertTrue(prefix.startswith(_state_dict_prefix))
model.register_state_dict_pre_hook(my_state_dict_pre_hook)
# Test to ensure submodules run the hook as well.
submodule.register_state_dict_pre_hook(my_state_dict_pre_hook)
def check_results(model):
nonlocal state_dict_pre_hook_count, keep_var_setting
for keep_var_setting in [True, False]:
_ = model.state_dict(
prefix=_state_dict_prefix, keep_vars=keep_var_setting
)
self.assertEqual(2, state_dict_pre_hook_count)
state_dict_pre_hook_count = 0
# Test state dict works as expected after model construction
check_results(model)
# Test state dict works as expected after forward
model(torch.ones(10, 3))
check_results(model)
def test_register_state_dict_pre_hook(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = nn.Sequential(
nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)
)
def forward(self, x):
return self.a(x)
mod = MyModule()
self._test_register_state_dict_pre_hook(mod, mod.a)
def test_register_state_dict_pre_hook_lazy_module(self):
class MyLazyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.LazyLinear(8)
self.layer2 = nn.LazyLinear(5)
def forward(self, x):
return self.layer2(self.layer1(x))
mod = MyLazyModule()
self._test_register_state_dict_pre_hook(mod, mod.layer1)
@unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows")
def test_register_state_dict_pre_hook_backward_compat(self):
called = False
def my_state_dict_pre_hook(*args, **kwargs):
nonlocal called
called = True
m = nn.Linear(1, 1)
self.assertTrue(hasattr(m, "_state_dict_pre_hooks"))
delattr(m, "_state_dict_pre_hooks")
# Save and load, ensure we can still call state_dict
# without running into issues.
with NamedTemporaryFile() as f:
# Note that torch.save / torch.load is not recommended
# to save / load modules.
torch.save(m, f.name)
m = torch.load(f.name)
# Ensure we can run state_dict without issues
_ = m.state_dict()
self.assertFalse(called)
m.register_state_dict_pre_hook(my_state_dict_pre_hook)
_ = m.state_dict()
self.assertTrue(called)
class TestModuleGlobalHooks(TestCase):
def tearDown(self):
@ -1553,6 +1641,7 @@ class TestModuleHookNN(NNTestCase):
instantiate_parametrized_tests(TestModuleHooks)
instantiate_parametrized_tests(TestStateDictHooks)
if __name__ == "__main__":
run_tests()

View File

@ -2282,58 +2282,6 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
# Reference https://github.com/pytorch/pytorch/pull/75507#issuecomment-1110291545
self.assertNotWarn(lambda: l.state_dict(destination=dict()), "Should not warn kwarg destination w/o _metadata")
def _test_register_state_dict_pre_hook(self, model, submodule):
_state_dict_prefix = "foo."
state_dict_pre_hook_count = 0
keep_var_setting = False
def my_state_dict_pre_hook(module, prefix, keep_vars):
self.assertEqual(keep_vars, keep_var_setting)
nonlocal state_dict_pre_hook_count
state_dict_pre_hook_count += 1
self.assertTrue(prefix.startswith(_state_dict_prefix))
model.register_state_dict_pre_hook(my_state_dict_pre_hook)
# Test to ensure submodules run the hook as well.
submodule.register_state_dict_pre_hook(my_state_dict_pre_hook)
def check_results(model):
nonlocal state_dict_pre_hook_count, keep_var_setting
for keep_var_setting in [True, False]:
_ = model.state_dict(prefix=_state_dict_prefix, keep_vars=keep_var_setting)
self.assertEqual(2, state_dict_pre_hook_count)
state_dict_pre_hook_count = 0
# Test state dict works as expected after model construction
check_results(model)
# Test state dict works as expected after forward
model(torch.ones(10, 3))
check_results(model)
def test_register_state_dict_pre_hook(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3))
def forward(self, x):
return self.a(x)
mod = MyModule()
self._test_register_state_dict_pre_hook(mod, mod.a)
def test_register_state_dict_pre_hook_lazy_module(self):
class MyLazyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.LazyLinear(8)
self.layer2 = nn.LazyLinear(5)
def forward(self, x):
return self.layer2(self.layer1(x))
mod = MyLazyModule()
self._test_register_state_dict_pre_hook(mod, mod.layer1)
def test_extra_state(self):
class SubModule(torch.nn.Module):