[easy] Move state_dict hooks tests to test_module_hooks and decorate tests that call load_state_dict with swap (#126906)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126906 Approved by: https://github.com/albanD
This commit is contained in:
parent 58083ffb10
commit a2d4fea872
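Two things happen below: the state_dict hook tests move out of TestLoadStateDict (and the older copies in the big test_nn suite) into TestStateDictHooks in test_module_hooks, and each test that calls load_state_dict gains a @swap([True, False]) decorator. swap comes from torch.testing._internal.common_utils and parametrizes the test over the swap-parameters-on-load behavior (understood here to be the torch.__future__ swap_module_params_on_conversion flag; treat that detail as an assumption), so each decorated test runs once with swapping enabled and once without. A minimal sketch of the decorator pattern, mirroring how the tests below use it:

import torch.nn as nn
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    swap,
    TestCase,
)


class TestSwapExample(TestCase):
    # Expanded into two tests, one per swap setting; `swap` toggles the
    # behavior for the duration of the test and restores it afterwards.
    @swap([True, False])
    def test_roundtrip_load(self):
        m = nn.Linear(2, 2)
        m.load_state_dict(m.state_dict())


instantiate_parametrized_tests(TestSwapExample)

if __name__ == "__main__":
    run_tests()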
@@ -3,14 +3,12 @@ import re
import unittest
from copy import deepcopy
from itertools import product
from tempfile import NamedTemporaryFile

import torch
import torch.nn as nn
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_WINDOWS,
    parametrize,
    run_tests,
    skipIfCrossRef,
@@ -206,33 +204,6 @@ class TestLoadStateDict(NNTestCase):
        model[0][0]._register_load_state_dict_pre_hook(hook_fn, with_module=True)
        model.load_state_dict(model.state_dict(), strict=True)

    @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows")
    @swap([True, False])
    def test_register_state_dict_pre_hook_backward_compat(self):
        called = False

        def my_state_dict_pre_hook(*args, **kwargs):
            nonlocal called
            called = True

        m = nn.Linear(1, 1)
        self.assertTrue(hasattr(m, "_state_dict_pre_hooks"))
        delattr(m, "_state_dict_pre_hooks")
        # Save and load, ensure we can still call state_dict
        # without running into issues.
        with NamedTemporaryFile() as f:
            # Note that torch.save / torch.load is not recommended
            # to save / load modules.
            torch.save(m, f.name)
            m = torch.load(f.name)

        # Ensure we can run state_dict without issues
        _ = m.state_dict()
        self.assertFalse(called)
        m.register_state_dict_pre_hook(my_state_dict_pre_hook)
        _ = m.state_dict()
        self.assertTrue(called)

    # fails swapping as LSTM installs weak references on the parameters
    @swap([False])
    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
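The two context lines at the top of the hunk above use the private _register_load_state_dict_pre_hook(hook_fn, with_module=True). With with_module=True, the hook receives the owning module as its first argument, ahead of the usual load_state_dict machinery arguments. A sketch of that hook shape, assuming the private signature current around this commit:

import torch.nn as nn


def hook_fn(module, state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs):
    # Runs before parameters are copied in; `prefix` scopes the keys
    # in `state_dict` that belong to this module.
    assert isinstance(module, nn.Linear)


m = nn.Linear(3, 3)
m._register_load_state_dict_pre_hook(hook_fn, with_module=True)
m.load_state_dict(m.state_dict(), strict=True)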
@@ -21,6 +21,7 @@ from torch.testing._internal.common_utils import (
    parametrize as parametrize_test,
    run_tests,
    skipIfTorchDynamo,
    swap,
    TestCase,
)
@@ -549,6 +550,7 @@ def _hook_to_pickle(*args, **kwargs):


class TestStateDictHooks(TestCase):
    @swap([True, False])
    def test_load_state_dict_pre_hook(self):
        m = nn.Linear(10, 10)
        m_state_dict = m.state_dict()
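test_load_state_dict_pre_hook exercises the variant without with_module, where the hook sees only the load_state_dict arguments. A typical use is rewriting checkpoint keys in place before loading; a sketch under the same assumed private signature, with the old key name "w" purely hypothetical:

import torch.nn as nn


def rename_keys(state_dict, prefix, local_metadata, strict,
                missing_keys, unexpected_keys, error_msgs):
    # Hypothetical migration: old checkpoints stored "weight" as "w".
    old_key, new_key = prefix + "w", prefix + "weight"
    if old_key in state_dict:
        state_dict[new_key] = state_dict.pop(old_key)


m = nn.Linear(10, 10)
m._register_load_state_dict_pre_hook(rename_keys)
sd = m.state_dict()
sd["w"] = sd.pop("weight")  # simulate the old checkpoint layout
m.load_state_dict(sd, strict=True)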
@@ -613,6 +615,7 @@ class TestStateDictHooks(TestCase):
        m._register_load_state_dict_pre_hook(_hook_to_pickle, True)
        pickle.loads(pickle.dumps(m))

    @swap([True, False])
    def test_load_state_dict_module_pre_hook(self):
        hook_called = 0
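The pickle.loads(pickle.dumps(m)) context line above is why the suite defines _hook_to_pickle at module level: pickle serializes functions by qualified name, so a lambda or local closure registered as a hook makes the whole module unpicklable. A short illustration of the constraint:

import pickle

import torch.nn as nn


def top_level_hook(*args, **kwargs):
    # Importable by qualified name, so pickle can record it by reference.
    pass


m = nn.Linear(2, 2)
m._register_load_state_dict_pre_hook(top_level_hook, True)
m2 = pickle.loads(pickle.dumps(m))  # round-trips, hook included

try:
    pickle.dumps(lambda *a, **k: None)  # what a lambda hook would hit
except (pickle.PicklingError, AttributeError) as e:
    print("not picklable:", e)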
@@ -686,6 +689,7 @@ class TestStateDictHooks(TestCase):
        m.load_state_dict(state_dict)
        self.assertEqual(2, hook_called)

    @swap([True, False])
    def test_load_state_dict_post_hook(self):
        hook_called = 0
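test_load_state_dict_post_hook covers the public post-hook API: after loading, each hook is called with the module and the incompatible_keys result (mutable missing_keys / unexpected_keys lists), and whatever the hooks leave behind is what load_state_dict reports. A hedged sketch; the "stale_buffer" key is invented for the example:

import torch.nn as nn


def ignore_stale_buffer(module, incompatible_keys):
    # Drop a key we know is benign so callers see a clean result.
    if "stale_buffer" in incompatible_keys.unexpected_keys:
        incompatible_keys.unexpected_keys.remove("stale_buffer")


m = nn.Linear(4, 4)
m.register_load_state_dict_post_hook(ignore_stale_buffer)
sd = m.state_dict()
sd["stale_buffer"] = sd["bias"].clone()  # simulate a leftover entry
ret = m.load_state_dict(sd, strict=False)
assert ret.unexpected_keys == []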
@@ -743,6 +747,7 @@ class TestStateDictHooks(TestCase):
        self.assertEqual([], ret.unexpected_keys)

    @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows")
    @swap([True, False])
    def test_load_state_dict_post_hook_backward_compatibility(self):
        def my_post_load_hook(mod, _):
            nonlocal called
@@ -771,6 +776,89 @@ class TestStateDictHooks(TestCase):
        m.load_state_dict(sd)
        self.assertTrue(called)

    def _test_register_state_dict_pre_hook(self, model, submodule):
        _state_dict_prefix = "foo."
        state_dict_pre_hook_count = 0
        keep_var_setting = False

        def my_state_dict_pre_hook(module, prefix, keep_vars):
            self.assertEqual(keep_vars, keep_var_setting)
            nonlocal state_dict_pre_hook_count
            state_dict_pre_hook_count += 1
            self.assertTrue(prefix.startswith(_state_dict_prefix))

        model.register_state_dict_pre_hook(my_state_dict_pre_hook)
        # Test to ensure submodules run the hook as well.
        submodule.register_state_dict_pre_hook(my_state_dict_pre_hook)

        def check_results(model):
            nonlocal state_dict_pre_hook_count, keep_var_setting
            for keep_var_setting in [True, False]:
                _ = model.state_dict(
                    prefix=_state_dict_prefix, keep_vars=keep_var_setting
                )
                self.assertEqual(2, state_dict_pre_hook_count)
                state_dict_pre_hook_count = 0

        # Test state dict works as expected after model construction
        check_results(model)
        # Test state dict works as expected after forward
        model(torch.ones(10, 3))
        check_results(model)

    def test_register_state_dict_pre_hook(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = nn.Sequential(
                    nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)
                )

            def forward(self, x):
                return self.a(x)

        mod = MyModule()
        self._test_register_state_dict_pre_hook(mod, mod.a)

    def test_register_state_dict_pre_hook_lazy_module(self):
        class MyLazyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = nn.LazyLinear(8)
                self.layer2 = nn.LazyLinear(5)

            def forward(self, x):
                return self.layer2(self.layer1(x))

        mod = MyLazyModule()
        self._test_register_state_dict_pre_hook(mod, mod.layer1)

    @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows")
    def test_register_state_dict_pre_hook_backward_compat(self):
        called = False

        def my_state_dict_pre_hook(*args, **kwargs):
            nonlocal called
            called = True

        m = nn.Linear(1, 1)
        self.assertTrue(hasattr(m, "_state_dict_pre_hooks"))
        delattr(m, "_state_dict_pre_hooks")
        # Save and load, ensure we can still call state_dict
        # without running into issues.
        with NamedTemporaryFile() as f:
            # Note that torch.save / torch.load is not recommended
            # to save / load modules.
            torch.save(m, f.name)
            m = torch.load(f.name)

        # Ensure we can run state_dict without issues
        _ = m.state_dict()
        self.assertFalse(called)
        m.register_state_dict_pre_hook(my_state_dict_pre_hook)
        _ = m.state_dict()
        self.assertTrue(called)


class TestModuleGlobalHooks(TestCase):
    def tearDown(self):
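The _test_register_state_dict_pre_hook helper above pins down the pre-hook contract on the saving side: every hook registered via the public register_state_dict_pre_hook fires once per state_dict() call with (module, prefix, keep_vars), including hooks registered on submodules, which see their own extended prefix (hence the expected count of 2). A standalone sketch of that contract:

import torch.nn as nn

calls = []


def log_hook(module, prefix, keep_vars):
    # Fires before the module writes its entries; submodules get the
    # parent prefix extended with their own name.
    calls.append((type(module).__name__, prefix, keep_vars))


model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
model.register_state_dict_pre_hook(log_hook)
model[0].register_state_dict_pre_hook(log_hook)

_ = model.state_dict(prefix="foo.", keep_vars=True)
print(calls)  # expect [('Sequential', 'foo.', True), ('Linear', 'foo.0.', True)]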
@@ -1553,6 +1641,7 @@ class TestModuleHookNN(NNTestCase):


instantiate_parametrized_tests(TestModuleHooks)
instantiate_parametrized_tests(TestStateDictHooks)

if __name__ == "__main__":
    run_tests()
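instantiate_parametrized_tests is what makes the @swap(...) / @parametrize(...) decorators take effect: it expands each decorated method into one concrete test per parameter value on the class. A minimal sketch with parametrize (the exact generated names are an assumption, roughly test_something_strict_True and test_something_strict_False):

from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


class TestExpansion(TestCase):
    @parametrize("strict", [True, False])
    def test_something(self, strict):
        self.assertIn(strict, (True, False))


instantiate_parametrized_tests(TestExpansion)

if __name__ == "__main__":
    run_tests()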
@@ -2282,58 +2282,6 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
        # Reference https://github.com/pytorch/pytorch/pull/75507#issuecomment-1110291545
        self.assertNotWarn(lambda: l.state_dict(destination=dict()), "Should not warn kwarg destination w/o _metadata")

    def _test_register_state_dict_pre_hook(self, model, submodule):
        _state_dict_prefix = "foo."
        state_dict_pre_hook_count = 0
        keep_var_setting = False

        def my_state_dict_pre_hook(module, prefix, keep_vars):
            self.assertEqual(keep_vars, keep_var_setting)
            nonlocal state_dict_pre_hook_count
            state_dict_pre_hook_count += 1
            self.assertTrue(prefix.startswith(_state_dict_prefix))

        model.register_state_dict_pre_hook(my_state_dict_pre_hook)
        # Test to ensure submodules run the hook as well.
        submodule.register_state_dict_pre_hook(my_state_dict_pre_hook)

        def check_results(model):
            nonlocal state_dict_pre_hook_count, keep_var_setting
            for keep_var_setting in [True, False]:
                _ = model.state_dict(prefix=_state_dict_prefix, keep_vars=keep_var_setting)
                self.assertEqual(2, state_dict_pre_hook_count)
                state_dict_pre_hook_count = 0
        # Test state dict works as expected after model construction
        check_results(model)
        # Test state dict works as expected after forward
        model(torch.ones(10, 3))
        check_results(model)

    def test_register_state_dict_pre_hook(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3))

            def forward(self, x):
                return self.a(x)

        mod = MyModule()
        self._test_register_state_dict_pre_hook(mod, mod.a)

    def test_register_state_dict_pre_hook_lazy_module(self):
        class MyLazyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = nn.LazyLinear(8)
                self.layer2 = nn.LazyLinear(5)

            def forward(self, x):
                return self.layer2(self.layer1(x))

        mod = MyLazyModule()
        self._test_register_state_dict_pre_hook(mod, mod.layer1)

    def test_extra_state(self):

        class SubModule(torch.nn.Module):
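For reference on the arguments the removed copy exercised: state_dict(prefix=..., keep_vars=...) prepends prefix to every key, and keep_vars=False (the default) detaches each returned tensor, while keep_vars=True hands back the live, autograd-tracked parameters. A short illustration:

import torch.nn as nn

m = nn.Linear(3, 3)

detached = m.state_dict()  # default: keep_vars=False
assert detached["weight"].requires_grad is False

live = m.state_dict(prefix="foo.", keep_vars=True)
assert live["foo.weight"].requires_grad is True
assert live["foo.weight"] is m.weight  # the parameter object itself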