This PR fixes several bugs, listed in priority order:

1. `load_state_dict` with a non-tensor `step` was incorrect for the capturable and fused implementations, since `__setstate__` did not create the step tensors on the right device. This has been fixed (a sketch of the idea appears after the logs below).
2. The most recently added capturable implementations were missing the eager-mode check that all tensors must be on CUDA. Those checks have now been added.
3. The most recent Adamax change only adds capturable support for foreach and is silently incorrect for the forloop/single-tensor path. I've added erroring for that case and modified testing with many skips. Honestly, this PR has only further cemented my preference that we implement the single-tensor and multi-tensor capturable paths together in the future. @mlazos
4. The conditional for adding CUDA-supported configs to the optimizer infos was incorrect, so we hadn't been testing capturable at all! This is also rectified and was the trigger for this PR in the first place.
5. In a similar way, the conditional for `_get_optim_inputs_including_global_cliquey_kwargs` was sometimes incorrect as well. This has also been corrected.

The following is not a bug, but simplifies life by removing the need to handle Nones: `optim_input_funcs` must now take a `device`, which can be a string or a `torch.device`.

Details for posterity:

4. Running the test_foreach_matches_forloop test and printing the configs that get run now shows capturable being included, which is correct.

```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5d50138f)]$ python test/test_optim.py -k test_foreach_matches_forloop_AdamW_cuda
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={}, desc=default
params=None, kwargs={'lr': 0.01}, desc=non-default lr
params=None, kwargs={'weight_decay': 0.1}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'maximize': True}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True}, desc=amsgrad
params=None, kwargs={'capturable': True}, desc=capturable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True}, desc=capturable, amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True}, desc=Tensor lr with capturable and amsgrad
.
----------------------------------------------------------------------
Ran 1 test in 19.229s

OK
```

5. Running the test_optimizer_can_be_printed test (which calls `_get_optim_inputs_including_global_cliquey_kwargs`) and printing what gets run is also now correct.
```
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={'differentiable': False}, desc=default
params=None, kwargs={'differentiable': True}, desc=default & differentiable
params=None, kwargs={'lr': 0.01, 'differentiable': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'differentiable': True}, desc=non-default lr & differentiable
params=None, kwargs={'weight_decay': 0.1, 'differentiable': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'differentiable': True}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': True}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': True}, desc=amsgrad & differentiable
.params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': False}, desc=default
params=None, kwargs={'foreach': True, 'differentiable': False, 'fused': False}, desc=default & foreach
params=None, kwargs={'foreach': False, 'differentiable': True, 'fused': False}, desc=default & differentiable
params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': True}, desc=default & fused
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'foreach': True, 'differentiable': False, 'fused': False}, desc=non-default lr & foreach
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': True, 'fused': False}, desc=non-default lr & differentiable
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': True}, desc=non-default lr & fused
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'foreach': True, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay & foreach
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': True, 'fused': False}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': True}, desc=nonzero weight_decay & fused
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=maximize & foreach
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=maximize & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=amsgrad & fused
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable
params=None, kwargs={'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable & foreach
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable & differentiable
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable, amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable, amsgrad & fused
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad & foreach
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=Tensor lr with capturable and amsgrad & differentiable
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=Tensor lr with capturable and amsgrad & fused
.
----------------------------------------------------------------------
Ran 2 tests in 11.112s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118326
Approved by: https://github.com/mlazos
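Bug 1 is easier to see in code. Below is a minimal sketch of the device rule the capturable/fused paths need when an old checkpoint stores `step` as a plain Python number; it is an illustration under stated assumptions, not the actual diff in this PR, and the helper name `_restore_step` is made up for this example.

```python
import torch

def _restore_step(step, param, capturable=False, fused=False):
    # Hypothetical helper (not a PyTorch API): old checkpoints may store
    # `step` as a plain float/int rather than a tensor.
    if not torch.is_tensor(step):
        step = torch.tensor(float(step), dtype=torch.float32)
    # The capturable and fused kernels read `step` on the accelerator, so the
    # tensor has to live on the same device as the parameter it belongs to,
    # rather than defaulting to CPU.
    if capturable or fused:
        step = step.to(param.device)
    return step
```

If I read the PR summary correctly, the real change applies this same rule inside each optimizer's `__setstate__`.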
642 lines · 29 KiB · Python
# Owner(s): ["module: nn"]
import unittest
import pickle

import torch
import torch.nn as nn
from torch.nn.parameter import UninitializedParameter, UninitializedBuffer
from torch.nn import Parameter
from torch.testing._internal.common_utils import TestCase, run_tests, suppress_warnings, TEST_PRIVATEUSE1
from torch.testing._internal.common_cuda import TEST_CUDA


class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
    pass

class TestLazyModules(TestCase):

    @suppress_warnings
    def test_lazy_module_parameter(self):
        module = LazyModule()
        module.register_parameter('test_param', UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        state_dict = module.state_dict()
        self.assertIsInstance(state_dict['test_param'], UninitializedParameter)
        new_module = LazyModule()
        # An error is raised when there is an attempt to replace an existing parameter
        # with an uninitialized one
        new_module.register_parameter('test_param', nn.Parameter(torch.ones(5, 5)))
        with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
            new_module.load_state_dict(state_dict)
        # Uninitialized parameters are overridden when the state dict to be loaded contains a valid one
        new_module = LazyModule()
        new_module.register_parameter('test_param', nn.Parameter(torch.ones(5, 5)))
        module.load_state_dict(new_module.state_dict())
        self.assertEqual(module.test_param, torch.ones((5, 5)))

        # Uninitialized parameters are left unchanged
        module = LazyModule()
        module.register_parameter('test_param', UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())

        new_module = LazyModule()
        new_module.register_parameter('test_param', UninitializedParameter())
        module.load_state_dict(new_module.state_dict())
        self.assertTrue(module.has_uninitialized_params())
    @suppress_warnings
    def test_lazy_module_buffer(self):
        module = LazyModule()
        module.register_buffer('test_buffer', UninitializedBuffer())
        self.assertTrue(module.has_uninitialized_params())
        state_dict = module.state_dict()
        self.assertIsInstance(state_dict['test_buffer'], UninitializedBuffer)
        new_module = LazyModule()
        # An error is raised when there is an attempt to replace an existing parameter
        # with an uninitialized one
        new_module.register_buffer('test_buffer', torch.ones(5, 5))
        with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
            new_module.load_state_dict(state_dict)
        # Uninitialized parameters are overridden when the state dict to be loaded contains a valid one
        new_module = LazyModule()
        new_module.register_buffer('test_buffer', torch.ones(5, 5))
        module.load_state_dict(new_module.state_dict())
        self.assertEqual(module.test_buffer, torch.ones((5, 5)))

        # Uninitialized parameters are left unchanged
        module = LazyModule()
        module.register_buffer('test_buffer', UninitializedBuffer())
        self.assertTrue(module.has_uninitialized_params())

        new_module = LazyModule()
        new_module.register_buffer('test_buffer', UninitializedBuffer())
        module.load_state_dict(new_module.state_dict())
        module.load_state_dict(new_module.state_dict())
        self.assertTrue(module.has_uninitialized_params())
    @suppress_warnings
    def test_lazy_module_jit_param(self):
        module = LazyModule()
        module.register_parameter('test_param', UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, 'run a forward pass'):
            torch.jit.script(module)

    @suppress_warnings
    def test_lazy_module_jit_buffer(self):
        module = LazyModule()
        module.register_buffer('test_buffer', UninitializedBuffer())
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, 'run a forward pass'):
            torch.jit.script(module)

    @suppress_warnings
    def test_lazy_share_memory_param(self):
        module = LazyModule()
        module.register_parameter('test_param', UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, 'share memory on an uninitialized'):
            module.share_memory()

    @suppress_warnings
    def test_lazy_share_memory_buffer(self):
        module = LazyModule()
        module.register_buffer('test_buffer', UninitializedBuffer())
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, 'share memory on an uninitialized'):
            module.share_memory()
    @suppress_warnings
    def test_linear(self):
        module = nn.LazyLinear(10)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(5, 5)
        module(input)
        self.assertIsInstance(module, nn.Linear)
        self.assertNotIsInstance(module, nn.LazyLinear)
        self.assertTrue(module.weight.shape == (10, 5))
        self.assertTrue(module.bias.shape == (10,))
        y = module(input)
        self.assertTrue(torch.equal(torch.nn.functional.linear(input, module.weight, module.bias), y))

    @suppress_warnings
    def test_lazy_linear_pickle(self):
        module = nn.LazyLinear(10)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(module, nn.LazyLinear)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(5, 5)
        module(input)  # fully materialized
        new_module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(new_module, nn.Linear)
        self.assertNotIsInstance(new_module, nn.LazyLinear)
        self.assertTrue(new_module.weight.shape == (10, 5))
        self.assertNotIsInstance(new_module.weight, UninitializedParameter)
        self.assertTrue(new_module.bias.shape == (10,))
        self.assertNotIsInstance(new_module.bias, UninitializedParameter)

    @suppress_warnings
    def test_linear_state(self):
        module = nn.Linear(5, 10)
        lazy_module = nn.LazyLinear(10)
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized but the module won't become a full
        # Linear one until the first iteration. This is due to
        # limitations on the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertTrue(lazy_module.weight.shape == (10, 5))
        self.assertTrue(lazy_module.bias.shape == (10,))

        module = nn.Linear(5, 10)
        lazy_module = nn.LazyLinear(10)
        with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
            module.load_state_dict(lazy_module.state_dict())
    def _check_lazy_conv(self, cls, lazy_cls, func, init_args, input_shape,
                         expected_weight_shape, expected_bias_shape):
        module = lazy_cls(*init_args)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(*input_shape)
        module(input)
        self.assertIsInstance(module, cls)
        self.assertNotIsInstance(module, lazy_cls)
        self.assertEqual(module.weight.shape, expected_weight_shape)
        if module.bias is not None:
            self.assertEqual(module.bias.shape, expected_bias_shape)
        y = module(input)
        self.assertTrue(torch.equal(func(input, module.weight, module.bias), y))

    def _check_lazy_conv_pickle(self, cls, lazy_cls, init_args, input_shape,
                                expected_weight_shape, expected_bias_shape):
        module = lazy_cls(*init_args)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(module, lazy_cls)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(*input_shape)
        module(input)  # fully materialized
        new_module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(new_module, cls)
        self.assertNotIsInstance(new_module, lazy_cls)
        self.assertEqual(new_module.weight.shape, expected_weight_shape)
        self.assertNotIsInstance(new_module.weight, UninitializedParameter)
        if new_module.bias is not None:
            self.assertEqual(new_module.bias.shape, expected_bias_shape)
            self.assertNotIsInstance(new_module.bias, UninitializedParameter)

    def _check_lazy_conv_state(self, gen_module, gen_lazy_module,
                               expected_weight_shape, expected_bias_shape):
        module = gen_module()
        lazy_module = gen_lazy_module()
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized but the module won't become a full
        # Conv one until the first iteration. This is due to
        # limitations on the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertEqual(lazy_module.weight.shape, expected_weight_shape)
        if lazy_module.bias is not None:
            self.assertEqual(lazy_module.bias.shape, expected_bias_shape)

        module = gen_module()
        lazy_module = gen_lazy_module()
        with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
            module.load_state_dict(lazy_module.state_dict())
    def test_lazy_pre_forward_hook(self):
        """
        Test that a lazy module can register additional forward pre-hooks
        successfully.
        """
        class TestModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
            def initialize_parameters(self, input):
                return None

            def forward(self, input):
                return input

        def hook_function(module, input):
            return input[0] + 1

        module = TestModule()
        module.register_forward_pre_hook(hook_function)
        output = module(torch.zeros(2, 2))
        self.assertEqual(output, torch.ones(2, 2))

    def test_lazy_forward_hook(self):
        """
        Test that a lazy module can register additional forward hooks
        successfully.
        """
        class TestModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
            def initialize_parameters(self, input):
                return None

            def forward(self, input):
                return input

        def hook_function(module, input, output):
            return input[0] + 1

        module = TestModule()
        module.register_forward_hook(hook_function)
        output = module(torch.zeros(2, 2))
        self.assertEqual(output, torch.ones(2, 2))
    @suppress_warnings
    def test_lazy_conv1d(self):
        self._check_lazy_conv(nn.Conv1d, nn.LazyConv1d, torch.nn.functional.conv1d,
                              (32, 2), (192, 16, 50), (32, 16, 2), (32,))

    @suppress_warnings
    def test_lazy_conv1d_pickle(self):
        self._check_lazy_conv_pickle(nn.Conv1d, nn.LazyConv1d, (32, 2), (192, 16, 50),
                                     (32, 16, 2), (32,))

    @suppress_warnings
    def test_lazy_conv1d_state(self):
        self._check_lazy_conv_state(lambda: nn.Conv1d(16, 32, 2),
                                    lambda: nn.LazyConv1d(32, 2),
                                    (32, 16, 2), (32,))

    @suppress_warnings
    def test_lazy_conv2d(self):
        self._check_lazy_conv(nn.Conv2d, nn.LazyConv2d, torch.nn.functional.conv2d,
                              (32, 2), (192, 16, 8, 6), (32, 16, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv2d_pickle(self):
        self._check_lazy_conv_pickle(nn.Conv2d, nn.LazyConv2d, (32, 2), (192, 16, 8, 6),
                                     (32, 16, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv2d_state(self):
        self._check_lazy_conv_state(lambda: nn.Conv2d(16, 32, 2),
                                    lambda: nn.LazyConv2d(32, 2),
                                    (32, 16, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv3d(self):
        self._check_lazy_conv(nn.Conv3d, nn.LazyConv3d, torch.nn.functional.conv3d,
                              (32, 2), (192, 16, 8, 7, 6), (32, 16, 2, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv3d_pickle(self):
        self._check_lazy_conv_pickle(nn.Conv3d, nn.LazyConv3d, (32, 2), (192, 16, 8, 7, 6),
                                     (32, 16, 2, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv3d_state(self):
        self._check_lazy_conv_state(lambda: nn.Conv3d(16, 32, 2),
                                    lambda: nn.LazyConv3d(32, 2),
                                    (32, 16, 2, 2, 2), (32,))
    @suppress_warnings
    def test_lazy_conv_transposed1d(self):
        self._check_lazy_conv(nn.ConvTranspose1d, nn.LazyConvTranspose1d, torch.nn.functional.conv_transpose1d,
                              (32, 2), (192, 16, 50), (16, 32, 2), (32,))

    @suppress_warnings
    def test_lazy_conv_transpose1d_pickle(self):
        self._check_lazy_conv_pickle(nn.ConvTranspose1d, nn.LazyConvTranspose1d, (32, 2),
                                     (192, 16, 50), (16, 32, 2), (32,))

    @suppress_warnings
    def test_lazy_conv_transpose1d_state(self):
        self._check_lazy_conv_state(lambda: nn.ConvTranspose1d(16, 32, 2),
                                    lambda: nn.LazyConvTranspose1d(32, 2),
                                    (16, 32, 2), (32,))

    @suppress_warnings
    def test_lazy_conv_transpose2d(self):
        self._check_lazy_conv(nn.ConvTranspose2d, nn.LazyConvTranspose2d, torch.nn.functional.conv_transpose2d,
                              (32, 2), (192, 16, 8, 6), (16, 32, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv_transpose2d_pickle(self):
        self._check_lazy_conv_pickle(nn.ConvTranspose2d, nn.LazyConvTranspose2d, (32, 2),
                                     (192, 16, 8, 6), (16, 32, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv_transpose2d_state(self):
        self._check_lazy_conv_state(lambda: nn.ConvTranspose2d(16, 32, 2),
                                    lambda: nn.LazyConvTranspose2d(32, 2),
                                    (16, 32, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv_transpose3d(self):
        self._check_lazy_conv(nn.ConvTranspose3d, nn.LazyConvTranspose3d, torch.nn.functional.conv_transpose3d,
                              (32, 2), (192, 16, 8, 7, 6), (16, 32, 2, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv_transpose3d_pickle(self):
        self._check_lazy_conv_pickle(nn.ConvTranspose3d, nn.LazyConvTranspose3d, (32, 2),
                                     (192, 16, 8, 7, 6), (16, 32, 2, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv_transpose3d_state(self):
        self._check_lazy_conv_state(lambda: nn.ConvTranspose3d(16, 32, 2),
                                    lambda: nn.LazyConvTranspose3d(32, 2),
                                    (16, 32, 2, 2, 2), (32,))
    def _check_lazy_norm(self, cls, lazy_cls, input_shape):
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                lazy_module = lazy_cls(affine=affine, track_running_stats=track_running_stats)

                if affine:
                    self.assertIsInstance(lazy_module.weight, UninitializedParameter)
                    self.assertIsInstance(lazy_module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertIsInstance(lazy_module.running_mean, UninitializedBuffer)
                    self.assertIsInstance(lazy_module.running_var, UninitializedBuffer)

                input = torch.ones(*input_shape)
                lazy_output = lazy_module(input)
                self.assertIsInstance(lazy_module, cls)
                self.assertNotIsInstance(lazy_module, lazy_cls)

                num_features = input_shape[1]
                module = cls(num_features, affine=affine, track_running_stats=track_running_stats)
                expected_output = module(input)

                self.assertEqual(lazy_output, expected_output)
                if module.weight is not None:
                    self.assertEqual(lazy_module.weight.shape, module.weight.shape)
                    self.assertEqual(lazy_module.weight, module.weight)
                if module.bias is not None:
                    self.assertEqual(lazy_module.bias.shape, module.bias.shape)
                    self.assertEqual(lazy_module.bias, module.bias)
                if module.running_mean is not None:
                    self.assertEqual(lazy_module.running_mean.shape, module.running_mean.shape)
                    self.assertEqual(lazy_module.running_mean, module.running_mean)
                if module.running_var is not None:
                    self.assertEqual(lazy_module.running_var.shape, module.running_var.shape)
                    self.assertEqual(lazy_module.running_var, module.running_var)
                if module.num_batches_tracked is not None:
                    self.assertEqual(lazy_module.num_batches_tracked.shape, module.num_batches_tracked.shape)
                    self.assertEqual(lazy_module.num_batches_tracked, module.num_batches_tracked)

    def _check_lazy_norm_pickle(self, cls, lazy_cls, input_shape):
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                module = lazy_cls(affine=affine, track_running_stats=track_running_stats)
                module = pickle.loads(pickle.dumps(module))

                self.assertIsInstance(module, lazy_cls)
                if affine:
                    self.assertIsInstance(module.weight, UninitializedParameter)
                    self.assertIsInstance(module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertIsInstance(module.running_mean, UninitializedBuffer)
                    self.assertIsInstance(module.running_var, UninitializedBuffer)

                input = torch.ones(*input_shape)
                module(input)  # fully materialized
                module = pickle.loads(pickle.dumps(module))

                self.assertNotIsInstance(module, lazy_cls)
                self.assertIsInstance(module, cls)
                if affine:
                    self.assertNotIsInstance(module.weight, UninitializedParameter)
                    self.assertNotIsInstance(module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertNotIsInstance(module.running_mean, UninitializedBuffer)
                    self.assertNotIsInstance(module.running_var, UninitializedBuffer)
    def _check_lazy_batchnorm_state(self, cls, lazy_cls):
        module = cls(10)
        lazy_module = lazy_cls(affine=True, track_running_stats=True)
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized but the module won't become a full
        # BatchNorm one until the first iteration. This is due to
        # limitations on the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertEqual(lazy_module.weight.shape, (10,))
        self.assertEqual(lazy_module.bias.shape, (10,))
        self.assertEqual(lazy_module.running_mean.shape, (10,))
        self.assertEqual(lazy_module.running_var.shape, (10,))

        module = cls(10)
        lazy_module = lazy_cls()
        with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
            module.load_state_dict(lazy_module.state_dict())

    def _check_lazy_instancenorm_state(self, cls, lazy_cls):
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                module = cls(10, affine=affine, track_running_stats=track_running_stats)
                lazy_module = lazy_cls(affine=affine, track_running_stats=track_running_stats)
                lazy_module.load_state_dict(module.state_dict())
                # Parameters have been initialized but the module won't become a full
                # InstanceNorm one until the first iteration. This is due to
                # limitations on the state_dict loading logic
                self.assertFalse(lazy_module.has_uninitialized_params())
                if affine:
                    self.assertEqual(lazy_module.weight.shape, (10,))
                    self.assertEqual(lazy_module.bias.shape, (10,))
                if track_running_stats:
                    self.assertEqual(lazy_module.running_mean.shape, (10,))
                    self.assertEqual(lazy_module.running_var.shape, (10,))

        module = cls(10, affine=True, track_running_stats=True)
        lazy_module = lazy_cls(affine=True, track_running_stats=True)
        with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
            module.load_state_dict(lazy_module.state_dict())
    def _check_lazy_norm_with_dict_input(self, cls, lazy_cls, input_shape):
        input = {"input": torch.ones(*input_shape)}

        lazy_module = lazy_cls()
        lazy_output = lazy_module(**input)

        num_features = input_shape[1]
        module = cls(num_features)
        expected_output = module(**input)

        self.assertEqual(lazy_output, expected_output)
    def test_lazy_batchnorm1d(self):
        self._check_lazy_norm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
        self._check_lazy_norm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))

    def test_lazy_batchnorm1d_pickle(self):
        self._check_lazy_norm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
        self._check_lazy_norm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))

    def test_lazy_batchnorm1d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)
        self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)

    def test_lazy_batchnorm2d(self):
        self._check_lazy_norm(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))

    def test_lazy_batchnorm2d_pickle(self):
        self._check_lazy_norm_pickle(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))

    def test_lazy_batchnorm2d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)
        self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)

    def test_lazy_batchnorm3d(self):
        self._check_lazy_norm(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8))

    def test_lazy_batchnorm3d_pickle(self):
        self._check_lazy_norm_pickle(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8))

    def test_lazy_batchnorm3d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)
        self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)

    def test_lazy_instancenorm1d(self):
        self._check_lazy_norm(nn.InstanceNorm1d, nn.LazyInstanceNorm1d, (16, 3, 6))

    def test_lazy_instancenorm1d_pickle(self):
        self._check_lazy_norm_pickle(nn.InstanceNorm1d, nn.LazyInstanceNorm1d, (16, 3, 6))

    def test_lazy_instancenorm1d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm1d, nn.LazyInstanceNorm1d)
        self._check_lazy_instancenorm_state(nn.InstanceNorm1d, nn.LazyInstanceNorm1d)

    def test_lazy_instancenorm2d(self):
        self._check_lazy_norm(nn.InstanceNorm2d, nn.LazyInstanceNorm2d, (16, 3, 6, 7))

    def test_lazy_instancenorm2d_pickle(self):
        self._check_lazy_norm_pickle(nn.InstanceNorm2d, nn.LazyInstanceNorm2d, (16, 3, 6, 7))

    def test_lazy_instancenorm2d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm2d, nn.LazyInstanceNorm2d)
        self._check_lazy_instancenorm_state(nn.InstanceNorm2d, nn.LazyInstanceNorm2d)

    def test_lazy_instancenorm3d(self):
        self._check_lazy_norm(nn.InstanceNorm3d, nn.LazyInstanceNorm3d, (16, 3, 6, 7, 8))

    def test_lazy_instancenorm3d_pickle(self):
        self._check_lazy_norm_pickle(nn.InstanceNorm3d, nn.LazyInstanceNorm3d, (16, 3, 6, 7, 8))

    def test_lazy_instancenorm3d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm3d, nn.LazyInstanceNorm3d)
        self._check_lazy_instancenorm_state(nn.InstanceNorm3d, nn.LazyInstanceNorm3d)

    def test_lazy_batchnorm_with_dict_input(self):
        self._check_lazy_norm_with_dict_input(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
        self._check_lazy_norm_with_dict_input(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))
        self._check_lazy_norm_with_dict_input(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8))
    @suppress_warnings
    def test_materialize_dtype(self):
        module = LazyModule()
        module.register_parameter('test_param', UninitializedParameter())
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.dtype == torch.get_default_dtype())
        module = LazyModule()
        module.register_parameter('test_param', UninitializedParameter())
        module.half()
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.dtype == torch.float16)

    @unittest.skipIf(not (TEST_CUDA or TEST_PRIVATEUSE1), 'CUDA and PRIVATEUSE1 not available')
    @suppress_warnings
    def test_materialize_device(self):
        module = LazyModule()
        module.register_parameter('test_param', UninitializedParameter())
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.device.type == 'cpu')
        if TEST_CUDA:
            device = 'cuda'
        elif TEST_PRIVATEUSE1:
            device = torch._C._get_privateuse1_backend_name()
        module = LazyModule()
        module.register_parameter('test_param', UninitializedParameter())
        module.to(device)
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.device.type == device)
    @suppress_warnings
    def test_chained_initialization(self):
        class MyNetwork(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear_1 = torch.nn.LazyLinear(15)
                self.linear_2 = torch.nn.LazyLinear(10)

            def forward(self, x):
                y = self.linear_1(x)
                return self.linear_2(y)

        net = MyNetwork()
        net(torch.ones(5, 10))
        self.assertTrue(net.linear_1.weight.shape == (15, 10))
        self.assertTrue(net.linear_1.bias.shape == (15,))
        self.assertTrue(net.linear_2.weight.shape == (10, 15))
        self.assertTrue(net.linear_2.bias.shape == (10,))
    @suppress_warnings
    def test_optimizer_pass(self):
        # Add Adamax and RAdam when #118230 and #117836 are complete
        optimizers = [torch.optim.Adadelta, torch.optim.Adagrad, torch.optim.Adam,
                      torch.optim.AdamW, torch.optim.ASGD, torch.optim.SGD, torch.optim.Rprop,
                      torch.optim.RMSprop, torch.optim.LBFGS]

        def run_step(module, optim):
            self.assertIsInstance(optim.param_groups[0]['params'][0], UninitializedParameter)
            module.test_param.materialize(10)
            self.assertIsInstance(optim.param_groups[0]['params'][0], Parameter)
            self.assertNotIsInstance(optim.param_groups[0]['params'][0], UninitializedParameter)
            for p in module.parameters():
                p.grad = torch.rand_like(p)
            if isinstance(optim, torch.optim.LBFGS):
                optim.step(lambda: 1.0)
            else:
                optim.step()

        for optim_cls in optimizers:
            module = LazyModule()
            module.register_parameter('test_param', UninitializedParameter())
            if optim_cls is torch.optim.SGD:
                optim = optim_cls(module.parameters(), lr=0.0)
            elif optim_cls is torch.optim.Adagrad:
                with self.assertRaisesRegex(ValueError, 'uninitialized parameter'):
                    optim = optim_cls(module.parameters())
                continue
            else:
                optim = optim_cls(module.parameters())
            run_step(module, optim)
    @suppress_warnings
    def test_weight_norm(self):
        m = nn.LazyLinear(7)
        with self.assertRaisesRegex(ValueError, 'have uninitialized parameters.'):
            m = torch.nn.utils.weight_norm(m)

    @suppress_warnings
    def test_spectral_norm(self):
        m = nn.LazyLinear(7)
        with self.assertRaisesRegex(ValueError, 'have uninitialized parameters.'):
            m = torch.nn.utils.spectral_norm(m)

    @suppress_warnings
    def test_invalid_functions(self):
        param = torch.nn.parameter.UninitializedParameter()
        with self.assertRaisesRegex(ValueError, 'uninitialized parameter'):
            torch.empty_like(param)

        with self.assertRaisesRegex(ValueError, 'uninitialized parameter'):
            torch.add(param, param)

        with self.assertRaisesRegex(ValueError, 'uninitialized parameter'):
            param + param
if __name__ == '__main__':
    run_tests()