Move nn.Module.load_state_dict tests from test_nn.py to separate file (#118028)
Move these tests out so that, in https://github.com/pytorch/pytorch/pull/117913, they can be run with both `torch.nn.utils.set_swap_module_params_on_conversion(True)` and `set_swap_module_params_on_conversion(False)`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118028
Approved by: https://github.com/albanD
This commit is contained in:
parent 71655bccbe
commit b92819a039
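As a rough illustration of that goal, here is a minimal sketch of running a load_state_dict round-trip under both swap settings. It is not part of this commit: the TestLoadStateDictSwap class and its test are hypothetical, and it assumes the toggle is exposed as set_swap_module_params_on_conversion / get_swap_module_params_on_conversion in torch.__future__ (the commit message refers to a torch.nn.utils location), so adjust the import if it lives elsewhere.

import torch
from torch import __future__ as torch_future  # assumed location of the swap toggle
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


class TestLoadStateDictSwap(NNTestCase):  # hypothetical class, not in the PR
    @parametrize("swap", [True, False])
    def test_load_state_dict_roundtrip(self, swap):
        # Toggle swap-on-load, restoring the previous setting afterwards.
        old = torch_future.get_swap_module_params_on_conversion()
        torch_future.set_swap_module_params_on_conversion(swap)
        try:
            m = torch.nn.Linear(2, 2)
            # With swap enabled, loaded tensors are swapped into the module's
            # parameters instead of being copied into the existing ones.
            m.load_state_dict(m.state_dict())
        finally:
            torch_future.set_swap_module_params_on_conversion(old)


instantiate_parametrized_tests(TestLoadStateDictSwap)

if __name__ == "__main__":
    run_tests()

Per the message above, the linked PR's intent is to run all of the moved tests under both settings, not just a single case like this.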
@@ -1353,6 +1353,7 @@ exclude_patterns = [
    'test/nn/test_embedding.py',
    'test/nn/test_init.py',
    'test/nn/test_lazy_modules.py',
    'test/nn/test_load_state_dict.py',
    'test/nn/test_module_hooks.py',
    'test/nn/test_multihead_attention.py',
    'test/nn/test_packed_sequence.py',

test/nn/test_load_state_dict.py (new file, 369 lines)
@@ -0,0 +1,369 @@
# Owner(s): ["module: nn"]
from copy import deepcopy
from tempfile import NamedTemporaryFile
import unittest

import torch
import torch.nn as nn
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import TestCase, \
    TEST_NUMPY, IS_WINDOWS, skipIfTorchDynamo, instantiate_parametrized_tests, \
    run_tests

if TEST_NUMPY:
    import numpy as np


class TestLoadStateDict(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    def test_load_state_dict_invalid(self):
        m = torch.nn.Linear(2, 2, bias=False)

        state_dict = {'weight': np.random.randn(2, 2)}
        with self.assertRaisesRegex(RuntimeError,
                                    "expected torch.Tensor or Tensor-like object from checkpoint but received"):
            m.load_state_dict(state_dict)

        state_dict = {'weight': ((1., 1.), (2., 2.))}
        with self.assertRaisesRegex(RuntimeError,
                                    "expected torch.Tensor or Tensor-like object from checkpoint but received"):
            m.load_state_dict(state_dict)

    def test_load_state_dict_type(self):
        m = nn.Module()

        with self.assertRaisesRegex(TypeError,
                                    "Expected state_dict to be dict-like, got"):
            m.load_state_dict("")
        with self.assertRaisesRegex(TypeError,
                                    "Expected state_dict to be dict-like, got"):
            m.load_state_dict(2)

    def test_load_state_dict(self):
        l = nn.Linear(5, 5)
        block = nn.Module()
        block.conv1 = nn.Conv2d(3, 3, 3, bias=True)
        block.conv2 = nn.Conv2d(3, 3, 3, bias=False)
        net = nn.Module()
        net.linear1 = l
        net.linear2 = l
        net.bn = nn.BatchNorm2d(2)
        net.block = block
        net.add_module('empty', None)
        conv1_bias_dtype = block.conv1.bias.dtype

        state_dict = net.state_dict()
        state_dict.update({
            'linear1.weight': torch.ones(5, 5),
            'block.conv1.bias': torch.arange(1, 4, dtype=conv1_bias_dtype),
            'bn.running_mean': torch.randn(2),
        })
        # Also test if a DDP state_dict can be loaded from a local model.
        ddp_state_dict = net.state_dict()
        ddp_state_dict.update({
            'module.linear1.weight': torch.ones(5, 5),
            'module.block.conv1.bias': torch.arange(1, 4, dtype=conv1_bias_dtype),
            'module.bn.running_mean': torch.randn(2),
        })
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(ddp_state_dict, 'module.')
        for sd in [state_dict, ddp_state_dict]:
            incompatible_keys = net.load_state_dict(sd)
            self.assertEqual(len(incompatible_keys.missing_keys), 0)
            self.assertEqual(len(incompatible_keys.unexpected_keys), 0)
            self.assertNotIn('Incompatible', str(incompatible_keys))
            self.assertEqual(net.linear1.weight, sd['linear1.weight'])
            self.assertEqual(net.block.conv1.bias, sd['block.conv1.bias'])
            self.assertEqual(net.bn.running_mean, sd['bn.running_mean'])

        state_dict = net.state_dict()
        state_dict.update({'extra': torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 0)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn('extra', incompatible_keys.unexpected_keys)
        self.assertIn('Incompatible', str(incompatible_keys))

        state_dict = net.state_dict()
        state_dict.update({'extra.param': torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 0)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn('extra.param', incompatible_keys.unexpected_keys)

        state_dict = net.state_dict()
        del state_dict['linear1.weight']
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 1)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 0)
        self.assertIn('linear1.weight', incompatible_keys.missing_keys)
        state_dict.update({'extra.param': torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 1)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn('linear1.weight', incompatible_keys.missing_keys)
        self.assertIn('extra.param', incompatible_keys.unexpected_keys)

        state_dict = net.state_dict()
        state_dict.update({'bn.running_mean': torch.rand(14, 4)})  # wrong size
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict, strict=False))

        state_dict = net.state_dict()
        old_state_dict = deepcopy(state_dict)
        state_dict = {
            'linear1.weight': torch.ones(5, 5),
            'block.conv1.bias': torch.arange(1, 4, dtype=conv1_bias_dtype),
            'bn.running_mean': torch.randn(2),
            'nonexistent_key': torch.rand(3)
        }
        net.load_state_dict(state_dict, strict=False)
        self.assertEqual(net.linear1.weight, state_dict['linear1.weight'])
        self.assertEqual(net.block.conv1.bias, state_dict['block.conv1.bias'])
        self.assertEqual(net.bn.running_mean, state_dict['bn.running_mean'])
        new_state_dict = net.state_dict()
        del old_state_dict['linear1.weight']
        del old_state_dict['block.conv1.bias']
        del old_state_dict['bn.running_mean']
        for k, v, in old_state_dict.items():
            self.assertTrue(v.equal(new_state_dict[k]))

    def test_load_state_dict_BC(self):
        # BatchNormNd
        # Added num_batches_tracked buffer at version 2. For state dict with
        # earlier versions or no versions, it should provide default value of 0.
        bn = nn.BatchNorm2d(3)
        state_dict = bn.state_dict()
        del state_dict['num_batches_tracked']
        state_dict._metadata['']['version'] = 1  # version 1
        bn.load_state_dict(state_dict)
        self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
        self.assertEqual(bn.num_batches_tracked.item(), 0)
        del state_dict._metadata['']['version']  # no version
        bn.load_state_dict(state_dict)
        self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
        self.assertEqual(bn.num_batches_tracked.item(), 0)

    def test_load_state_dict_child(self):
        base_module = nn.Linear(1, 1)
        model = base_module
        for _ in range(3):
            model = nn.Sequential(*[deepcopy(model) for _ in range(10)])

        def hook_fn(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            module_state_dict = module.state_dict()
            self.assertEqual(len(module_state_dict.keys()), len(state_dict.keys()))

        model[0][0]._register_load_state_dict_pre_hook(hook_fn, with_module=True)
        model.load_state_dict(model.state_dict(), strict=True)

    @unittest.skipIf(IS_WINDOWS, "Tempfile permission issue on windows")
    def test_register_state_dict_pre_hook_backward_compat(self):
        called = False

        def my_state_dict_pre_hook(*args, **kwargs):
            nonlocal called
            called = True

        m = nn.Linear(1, 1)
        self.assertTrue(hasattr(m, '_state_dict_pre_hooks'))
        delattr(m, '_state_dict_pre_hooks')
        # Save and load, ensure we can still call state_dict
        # without running into issues.
        with NamedTemporaryFile() as f:
            # Note that torch.save / torch.load is not recommended
            # to save / load modules.
            torch.save(m, f.name)
            m = torch.load(f.name)

        # Ensure we can run state_dict without issues
        _ = m.state_dict()
        self.assertFalse(called)
        m.register_state_dict_pre_hook(my_state_dict_pre_hook)
        _ = m.state_dict()
        self.assertTrue(called)

    # FIXME: doesn't fail locally, maybe remove
    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
    def test_load_state_dict_ref_cycle(self):
        # load_state_dict shouldn't cause a reference cycle involving Tensors
        import gc

        m = torch.nn.LSTM(16, 16, bidirectional=True)

        gc.collect()
        m.load_state_dict(deepcopy(m).state_dict())
        refcycles = gc.collect()

        self.assertEqual(refcycles, 0)

    def test_load_state_dict_custom(self):

        class CustomState(nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.ones(1))
                self.sub = torch.nn.Linear(5, 5)

            def _save_to_state_dict(self, destination, prefix, keep_vars):
                destination[prefix + "serialized"] = self.param.data + 1

            def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs):
                # skip some of the error handling
                self.param.data.copy_(state_dict[prefix + "serialized"] - 1)

        # use sequential to verify nesting
        m = nn.Sequential(CustomState())
        with torch.no_grad():
            m[0].param[0] = 10
            m[0].sub.weight[0, 0] = 555
        state_dict = m.state_dict()
        self.assertEqual(state_dict["0.serialized"].item(), 11)
        self.assertIn("0.sub.weight", state_dict)
        self.assertNotIn("0.param", state_dict)
        del m
        mm = nn.Sequential(CustomState())
        self.assertEqual(mm[0].param[0].item(), 1)
        mm.load_state_dict(state_dict)
        self.assertEqual(mm[0].param[0].item(), 10)
        self.assertEqual(mm[0].sub.weight[0, 0].item(), 555)

    def test_load_state_dict_assign_meta(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        state_dict = net.state_dict(keep_vars=True)

        with torch.device('meta'):
            net_meta = MyModule()

        net_meta.load_state_dict(state_dict, assign=True)

        # Make sure parameters and persistent buffers were assigned
        net_meta_state_dict = net_meta.state_dict(keep_vars=True)
        for key in state_dict.keys():
            if isinstance(state_dict[key], torch.nn.Parameter):
                self.assertTrue(state_dict[key] is net_meta_state_dict[key])

        # Make sure that ordering of parameters and buffers is preserved
        net_named_parameters = net.named_parameters()
        net_named_buffers = net.named_buffers()
        net_meta_named_parameters = net_meta.named_parameters()
        net_meta_named_buffers = net_meta.named_buffers()

        for p1, p2 in zip(net_named_parameters, net_meta_named_parameters):
            n1, _ = p1
            n2, _ = p2
            self.assertEqual(n1, n2)

        for p1, p2 in zip(net_named_buffers, net_meta_named_buffers):
            n1, _ = p1
            n2, _ = p2
            self.assertEqual(n1, n2)

        # Make sure outputs are the same
        t = torch.randn(4, 3)
        out_net = net(t)
        out_net_meta = net_meta(t.clone())

        self.assertEqual(out_net, out_net_meta)

    def test_load_state_dict_assign_with_optimizer(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        opt = torch.optim.Adam(net.parameters(), lr=1000)
        x = torch.randn(4, 3)
        num_iters = 3

        for i in range(num_iters):
            opt.zero_grad()
            out = net(x)
            out.sum().backward()
            opt.step()

        opt_state_dict = deepcopy(opt.state_dict())
        net_state_dict = deepcopy(net.state_dict())

        with torch.device('meta'):
            net_meta = MyModule()

        net_meta.load_state_dict(net_state_dict, assign=True)
        # must create optimizer only after loading state_dict when assign=True
        opt2 = torch.optim.Adam(net_meta.parameters(), lr=1000)
        opt2.load_state_dict(opt_state_dict)

        y = x.clone()
        for i in range(num_iters):
            opt.zero_grad()
            out = net(x)
            out.sum().backward()
            opt.step()

            opt2.zero_grad()
            out2 = net_meta(y)
            out2.sum().backward()
            opt2.step()

        self.assertEqual(opt.state_dict(), opt2.state_dict())
        self.assertEqual(net.state_dict(), net_meta.state_dict())

    def test_load_state_dict_assign_shape_stride(self):
        # Assigned tensor is allowed to have different properties than initial
        # tensor except for shape
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        state_dict = net.state_dict()
        # loading should be ok if stride is different
        state_dict['fc1.weight'] = torch.randn(3, 5).transpose(0, 1)
        net2 = MyModule()
        net2.load_state_dict(state_dict, strict=False, assign=True)

        state_dict['fc1.weight'] = torch.randn(2, 4)
        with self.assertRaisesRegex(RuntimeError, "size mismatch for fc1.weight: copying a param with shape"):
            net2.load_state_dict(state_dict, strict=False, assign=True)

    def test_load_state_dict_warn_assign(self):
        with torch.device('meta'):
            m = torch.nn.Linear(3, 5)
        state_dict = m.state_dict()
        state_dict['weight'] = torch.empty_like(state_dict['weight'], device='cpu')
        with self.assertWarnsRegex(UserWarning, "for weight: copying from a non-meta parameter in the checkpoint to a meta"):
            m.load_state_dict(state_dict)


instantiate_parametrized_tests(TestLoadStateDict)

if __name__ == '__main__':
    TestCase._default_dtype_check_enabled = True
    run_tests()

test/test_nn.py (345 lines changed)
@@ -13,7 +13,6 @@ from copy import deepcopy
from itertools import product
from functools import partial
from collections import OrderedDict
from tempfile import NamedTemporaryFile
from unittest import SkipTest

import torch
@@ -35,7 +34,7 @@ from torch.testing._internal.common_utils import freeze_rng_state, run_tests, Te
    download_file, get_function_arglist, load_tests, skipIfMps, \
    IS_PPC, \
    parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \
    skipIfTorchDynamo, IS_WINDOWS, gcIfJetson, set_default_dtype
    skipIfTorchDynamo, gcIfJetson, set_default_dtype
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION, \
    PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
@@ -630,30 +629,6 @@ class TestNN(NNTestCase):
    (removed: test_load_state_dict_invalid and test_load_state_dict_type, identical to the versions added in test/nn/test_load_state_dict.py above)
@@ -2277,153 +2252,6 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
    (removed: test_load_state_dict, test_load_state_dict_BC, test_load_state_dict_child, and test_register_state_dict_pre_hook_backward_compat, identical to the versions added in test/nn/test_load_state_dict.py above)
@@ -2476,52 +2304,6 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
    (removed: test_load_state_dict_ref_cycle and test_load_state_dict_custom, matching the versions added in test/nn/test_load_state_dict.py above)
@@ -2580,131 +2362,6 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
    (removed: test_load_state_dict_assign_meta, test_load_state_dict_assign_with_optimizer, test_load_state_dict_assign_shape_stride, and test_load_state_dict_warn_assign, identical to the versions added in test/nn/test_load_state_dict.py above)