pytorch/test/nn/test_lazy_modules.py
Jane Xu d947b9d500 [optim] Rectify capturable testing and fix bugs! (#118326)
This PR fixes several bugs, listed in order of priority:
1. `load_state_dict` with a non-tensor step was incorrect for the capturable and fused implementations, since `__setstate__` did not create the step tensors on the right device. This has been fixed (a sketch of the device handling follows this list).
2. The most recently added capturable implementations were missing the eager-mode check that all tensors must be on CUDA. Those checks have now been added.
3. The most recent change in Adamax only added capturable support for foreach, so the forloop/single-tensor path would be silently incorrect. I've added an error for that case and adjusted the tests with many skips. Honestly, this PR has only further cemented my preference that we land the single-tensor and multi-tensor capturable implementations together in the future. @mlazos
4. The conditional for adding CUDA-supported configs to the optimizer infos was incorrect, so we hadn't been testing capturable at all! This is now rectified and was the trigger for this PR in the first place.
5. Similarly, the conditional in `_get_optim_inputs_including_global_cliquey_kwargs` was sometimes incorrect as well. This has also been corrected.
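For (1), here is a minimal sketch of the kind of device-aware step handling described above. It is not the exact PR diff, and the helper name `_restore_step_tensors` is hypothetical:
```python
import torch


def _restore_step_tensors(optimizer_state, capturable, fused):
    """Hypothetical helper: recreate non-tensor `step` entries as tensors."""
    for param, param_state in optimizer_state.items():
        step = param_state["step"]
        if not torch.is_tensor(step):
            # Capturable/fused kernels read `step` on the GPU, so the tensor
            # must live on the parameter's device rather than defaulting to CPU.
            device = param.device if (capturable or fused) else "cpu"
            param_state["step"] = torch.tensor(
                float(step), dtype=torch.float32, device=device
            )
```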

The following is not a bug fix, but just something to make life simpler by removing the need to handle Nones: `optim_input_funcs` must now take a `device`, which can be a string or a torch.device.
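As a rough sketch of the new contract (assuming the `OptimizerInput` helper from `torch.testing._internal.common_optimizers`; the function name below is made up for illustration):
```python
import torch
from torch.testing._internal.common_optimizers import OptimizerInput


def example_optim_inputs_func(device):
    # `device` is now always provided (a string or torch.device); no None handling.
    device = torch.device(device)
    inputs = [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
    ]
    if device.type == "cuda":
        # Capturable configs are only exercised on CUDA devices.
        inputs.append(
            OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable")
        )
    return inputs
```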

Details for posterity:
4. Running the test_foreach_matches_forloop test and printing the generated configs shows that capturable now gets included, which is correct.
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5d50138f)]$ python test/test_optim.py -k test_foreach_matches_forloop_AdamW_cuda
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={}, desc=default
params=None, kwargs={'lr': 0.01}, desc=non-default lr
params=None, kwargs={'weight_decay': 0.1}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'maximize': True}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True}, desc=amsgrad
params=None, kwargs={'capturable': True}, desc=capturable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True}, desc=capturable, amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True}, desc=Tensor lr with capturable and amsgrad
.
----------------------------------------------------------------------
Ran 1 test in 19.229s

OK
```
5. Running the test_optimizer_can_be_printed test (which calls `_get_optim_inputs_including_global_cliquey_kwargs`) and printing what gets run shows that it is also now correct.
```
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={'differentiable': False}, desc=default
params=None, kwargs={'differentiable': True}, desc=default & differentiable
params=None, kwargs={'lr': 0.01, 'differentiable': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'differentiable': True}, desc=non-default lr & differentiable
params=None, kwargs={'weight_decay': 0.1, 'differentiable': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'differentiable': True}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': True}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': True}, desc=amsgrad & differentiable
.params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': False}, desc=default
params=None, kwargs={'foreach': True, 'differentiable': False, 'fused': False}, desc=default & foreach
params=None, kwargs={'foreach': False, 'differentiable': True, 'fused': False}, desc=default & differentiable
params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': True}, desc=default & fused
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'foreach': True, 'differentiable': False, 'fused': False}, desc=non-default lr & foreach
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': True, 'fused': False}, desc=non-default lr & differentiable
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': True}, desc=non-default lr & fused
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'foreach': True, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay & foreach
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': True, 'fused': False}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': True}, desc=nonzero weight_decay & fused
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=maximize & foreach
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=maximize & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=amsgrad & fused
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable
params=None, kwargs={'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable & foreach
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable & differentiable
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable, amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable, amsgrad & fused
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad & foreach
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=Tensor lr with capturable and amsgrad & differentiable
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=Tensor lr with capturable and amsgrad & fused
.
----------------------------------------------------------------------
Ran 2 tests in 11.112s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118326
Approved by: https://github.com/mlazos
2024-02-02 02:02:58 +00:00


# Owner(s): ["module: nn"]
import unittest
import pickle
import torch
import torch.nn as nn
from torch.nn.parameter import UninitializedParameter, UninitializedBuffer
from torch.nn import Parameter
from torch.testing._internal.common_utils import TestCase, run_tests, suppress_warnings, TEST_PRIVATEUSE1
from torch.testing._internal.common_cuda import TEST_CUDA

class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
    pass


class TestLazyModules(TestCase):

    @suppress_warnings
    def test_lazy_module_parameter(self):
        module = LazyModule()
        module.register_parameter('test_param', UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        state_dict = module.state_dict()
        self.assertIsInstance(state_dict['test_param'], UninitializedParameter)
        new_module = LazyModule()
        # An error is raised when there is an attempt to replace an existing parameter
        # with an uninitialized one
        new_module.register_parameter('test_param', nn.Parameter(torch.ones(5, 5)))
        with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
            new_module.load_state_dict(state_dict)
        # Uninitialized parameters are overridden when the state dict to be loaded contains a valid one
        new_module = LazyModule()
        new_module.register_parameter('test_param', nn.Parameter(torch.ones(5, 5)))
        module.load_state_dict(new_module.state_dict())
        self.assertEqual(module.test_param, torch.ones((5, 5)))
        # Uninitialized parameters are left unchanged
        module = LazyModule()
        module.register_parameter('test_param', UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        new_module = LazyModule()
        new_module.register_parameter('test_param', UninitializedParameter())
        module.load_state_dict(new_module.state_dict())
        self.assertTrue(module.has_uninitialized_params())

    @suppress_warnings
    def test_lazy_module_buffer(self):
        module = LazyModule()
        module.register_buffer('test_buffer', UninitializedBuffer())
        self.assertTrue(module.has_uninitialized_params())
        state_dict = module.state_dict()
        self.assertIsInstance(state_dict['test_buffer'], UninitializedBuffer)
        new_module = LazyModule()
        # An error is raised when there is an attempt to replace an existing parameter
        # with an uninitialized one
        new_module.register_buffer('test_buffer', torch.ones(5, 5))
        with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
            new_module.load_state_dict(state_dict)
        # Uninitialized parameters are overridden when the state dict to be loaded contains a valid one
        new_module = LazyModule()
        new_module.register_buffer('test_buffer', torch.ones(5, 5))
        module.load_state_dict(new_module.state_dict())
        self.assertEqual(module.test_buffer, torch.ones((5, 5)))
        # Uninitialized parameters are left unchanged
        module = LazyModule()
        module.register_buffer('test_buffer', UninitializedBuffer())
        self.assertTrue(module.has_uninitialized_params())
        new_module = LazyModule()
        new_module.register_buffer('test_buffer', UninitializedBuffer())
        module.load_state_dict(new_module.state_dict())
        self.assertTrue(module.has_uninitialized_params())
    @suppress_warnings
    def test_lazy_module_jit_param(self):
        module = LazyModule()
        module.register_parameter('test_param', UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, 'run a forward pass'):
            torch.jit.script(module)

    @suppress_warnings
    def test_lazy_module_jit_buffer(self):
        module = LazyModule()
        module.register_buffer('test_buffer', UninitializedBuffer())
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, 'run a forward pass'):
            torch.jit.script(module)

    @suppress_warnings
    def test_lazy_share_memory_param(self):
        module = LazyModule()
        module.register_parameter('test_param', UninitializedParameter())
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, 'share memory on an uninitialized'):
            module.share_memory()

    @suppress_warnings
    def test_lazy_share_memory_buffer(self):
        module = LazyModule()
        module.register_buffer('test_buffer', UninitializedBuffer())
        self.assertTrue(module.has_uninitialized_params())
        with self.assertRaisesRegex(RuntimeError, 'share memory on an uninitialized'):
            module.share_memory()

    @suppress_warnings
    def test_linear(self):
        module = nn.LazyLinear(10)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(5, 5)
        module(input)
        self.assertIsInstance(module, nn.Linear)
        self.assertNotIsInstance(module, nn.LazyLinear)
        self.assertTrue(module.weight.shape == (10, 5))
        self.assertTrue(module.bias.shape == (10,))
        y = module(input)
        self.assertTrue(torch.equal(torch.nn.functional.linear(input, module.weight, module.bias), y))
    @suppress_warnings
    def test_lazy_linear_pickle(self):
        module = nn.LazyLinear(10)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(module, nn.LazyLinear)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(5, 5)
        module(input)  # fully materialized
        new_module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(new_module, nn.Linear)
        self.assertNotIsInstance(new_module, nn.LazyLinear)
        self.assertTrue(new_module.weight.shape == (10, 5))
        self.assertNotIsInstance(new_module.weight, UninitializedParameter)
        self.assertTrue(new_module.bias.shape == (10,))
        self.assertNotIsInstance(new_module.bias, UninitializedParameter)

    @suppress_warnings
    def test_linear_state(self):
        module = nn.Linear(5, 10)
        lazy_module = nn.LazyLinear(10)
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized but the module won't become a full
        # Linear one until the first iteration. This is due to
        # limitations on the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertTrue(lazy_module.weight.shape == (10, 5))
        self.assertTrue(lazy_module.bias.shape == (10,))
        module = nn.Linear(5, 10)
        lazy_module = nn.LazyLinear(10)
        with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
            module.load_state_dict(lazy_module.state_dict())
    def _check_lazy_conv(self, cls, lazy_cls, func, init_args, input_shape,
                         expected_weight_shape, expected_bias_shape):
        module = lazy_cls(*init_args)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(*input_shape)
        module(input)
        self.assertIsInstance(module, cls)
        self.assertNotIsInstance(module, lazy_cls)
        self.assertEqual(module.weight.shape, expected_weight_shape)
        if module.bias is not None:
            self.assertEqual(module.bias.shape, expected_bias_shape)
        y = module(input)
        self.assertTrue(torch.equal(func(input, module.weight, module.bias), y))

    def _check_lazy_conv_pickle(self, cls, lazy_cls, init_args, input_shape,
                                expected_weight_shape, expected_bias_shape):
        module = lazy_cls(*init_args)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(module, lazy_cls)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(*input_shape)
        module(input)  # fully materialized
        new_module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(new_module, cls)
        self.assertNotIsInstance(new_module, lazy_cls)
        self.assertEqual(new_module.weight.shape, expected_weight_shape)
        self.assertNotIsInstance(new_module.weight, UninitializedParameter)
        if new_module.bias is not None:
            self.assertEqual(new_module.bias.shape, expected_bias_shape)
            self.assertNotIsInstance(new_module.bias, UninitializedParameter)

    def _check_lazy_conv_state(self, gen_module, gen_lazy_module,
                               expected_weight_shape, expected_bias_shape):
        module = gen_module()
        lazy_module = gen_lazy_module()
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized but the module won't become a full
        # Conv one until the first iteration. This is due to
        # limitations on the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertEqual(lazy_module.weight.shape, expected_weight_shape)
        if lazy_module.bias is not None:
            self.assertEqual(lazy_module.bias.shape, expected_bias_shape)
        module = gen_module()
        lazy_module = gen_lazy_module()
        with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
            module.load_state_dict(lazy_module.state_dict())
    def test_lazy_pre_forward_hook(self):
        """
        This test checks whether a lazy module can register other pre-forward hook
        functions successfully.
        """
        class TestModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
            def initialize_parameters(self, input):
                return None

            def forward(self, input):
                return input

        def hook_function(module, input):
            return input[0] + 1

        module = TestModule()
        module.register_forward_pre_hook(hook_function)
        output = module(torch.zeros(2, 2))
        self.assertEqual(output, torch.ones(2, 2))

    def test_lazy_forward_hook(self):
        """
        This test checks whether a lazy module can register other forward hook
        functions successfully.
        """
        class TestModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
            def initialize_parameters(self, input):
                return None

            def forward(self, input):
                return input

        def hook_function(module, input, output):
            return input[0] + 1

        module = TestModule()
        module.register_forward_hook(hook_function)
        output = module(torch.zeros(2, 2))
        self.assertEqual(output, torch.ones(2, 2))
    @suppress_warnings
    def test_lazy_conv1d(self):
        self._check_lazy_conv(nn.Conv1d, nn.LazyConv1d, torch.nn.functional.conv1d,
                              (32, 2), (192, 16, 50), (32, 16, 2), (32,))

    @suppress_warnings
    def test_lazy_conv1d_pickle(self):
        self._check_lazy_conv_pickle(nn.Conv1d, nn.LazyConv1d, (32, 2), (192, 16, 50),
                                     (32, 16, 2), (32,))

    @suppress_warnings
    def test_lazy_conv1d_state(self):
        self._check_lazy_conv_state(lambda: nn.Conv1d(16, 32, 2),
                                    lambda: nn.LazyConv1d(32, 2),
                                    (32, 16, 2), (32,))

    @suppress_warnings
    def test_lazy_conv2d(self):
        self._check_lazy_conv(nn.Conv2d, nn.LazyConv2d, torch.nn.functional.conv2d,
                              (32, 2), (192, 16, 8, 6), (32, 16, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv2d_pickle(self):
        self._check_lazy_conv_pickle(nn.Conv2d, nn.LazyConv2d, (32, 2), (192, 16, 8, 6),
                                     (32, 16, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv2d_state(self):
        self._check_lazy_conv_state(lambda: nn.Conv2d(16, 32, 2),
                                    lambda: nn.LazyConv2d(32, 2),
                                    (32, 16, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv3d(self):
        self._check_lazy_conv(nn.Conv3d, nn.LazyConv3d, torch.nn.functional.conv3d,
                              (32, 2), (192, 16, 8, 7, 6), (32, 16, 2, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv3d_pickle(self):
        self._check_lazy_conv_pickle(nn.Conv3d, nn.LazyConv3d, (32, 2), (192, 16, 8, 7, 6),
                                     (32, 16, 2, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv3d_state(self):
        self._check_lazy_conv_state(lambda: nn.Conv3d(16, 32, 2),
                                    lambda: nn.LazyConv3d(32, 2),
                                    (32, 16, 2, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv_transposed1d(self):
        self._check_lazy_conv(nn.ConvTranspose1d, nn.LazyConvTranspose1d, torch.nn.functional.conv_transpose1d,
                              (32, 2), (192, 16, 50), (16, 32, 2), (32,))

    @suppress_warnings
    def test_lazy_conv_transpose1d_pickle(self):
        self._check_lazy_conv_pickle(nn.ConvTranspose1d, nn.LazyConvTranspose1d, (32, 2),
                                     (192, 16, 50), (16, 32, 2), (32,))

    @suppress_warnings
    def test_lazy_conv_transpose1d_state(self):
        self._check_lazy_conv_state(lambda: nn.ConvTranspose1d(16, 32, 2),
                                    lambda: nn.LazyConvTranspose1d(32, 2),
                                    (16, 32, 2), (32,))

    @suppress_warnings
    def test_lazy_conv_transpose2d(self):
        self._check_lazy_conv(nn.ConvTranspose2d, nn.LazyConvTranspose2d, torch.nn.functional.conv_transpose2d,
                              (32, 2), (192, 16, 8, 6), (16, 32, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv_transpose2d_pickle(self):
        self._check_lazy_conv_pickle(nn.ConvTranspose2d, nn.LazyConvTranspose2d, (32, 2),
                                     (192, 16, 8, 6), (16, 32, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv_transpose2d_state(self):
        self._check_lazy_conv_state(lambda: nn.ConvTranspose2d(16, 32, 2),
                                    lambda: nn.LazyConvTranspose2d(32, 2),
                                    (16, 32, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv_transpose3d(self):
        self._check_lazy_conv(nn.ConvTranspose3d, nn.LazyConvTranspose3d, torch.nn.functional.conv_transpose3d,
                              (32, 2), (192, 16, 8, 7, 6), (16, 32, 2, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv_transpose3d_pickle(self):
        self._check_lazy_conv_pickle(nn.ConvTranspose3d, nn.LazyConvTranspose3d, (32, 2),
                                     (192, 16, 8, 7, 6), (16, 32, 2, 2, 2), (32,))

    @suppress_warnings
    def test_lazy_conv_transpose3d_state(self):
        self._check_lazy_conv_state(lambda: nn.ConvTranspose3d(16, 32, 2),
                                    lambda: nn.LazyConvTranspose3d(32, 2),
                                    (16, 32, 2, 2, 2), (32,))
    def _check_lazy_norm(self, cls, lazy_cls, input_shape):
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                lazy_module = lazy_cls(affine=affine, track_running_stats=track_running_stats)
                if affine:
                    self.assertIsInstance(lazy_module.weight, UninitializedParameter)
                    self.assertIsInstance(lazy_module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertIsInstance(lazy_module.running_mean, UninitializedBuffer)
                    self.assertIsInstance(lazy_module.running_var, UninitializedBuffer)
                input = torch.ones(*input_shape)
                lazy_output = lazy_module(input)
                self.assertIsInstance(lazy_module, cls)
                self.assertNotIsInstance(lazy_module, lazy_cls)
                num_features = input_shape[1]
                module = cls(num_features, affine=affine, track_running_stats=track_running_stats)
                expected_output = module(input)
                self.assertEqual(lazy_output, expected_output)
                if module.weight is not None:
                    self.assertEqual(lazy_module.weight.shape, module.weight.shape)
                    self.assertEqual(lazy_module.weight, module.weight)
                if module.bias is not None:
                    self.assertEqual(lazy_module.bias.shape, module.bias.shape)
                    self.assertEqual(lazy_module.bias, module.bias)
                if module.running_mean is not None:
                    self.assertEqual(lazy_module.running_mean.shape, module.running_mean.shape)
                    self.assertEqual(lazy_module.running_mean, module.running_mean)
                if module.running_var is not None:
                    self.assertEqual(lazy_module.running_var.shape, module.running_var.shape)
                    self.assertEqual(lazy_module.running_var, module.running_var)
                if module.num_batches_tracked is not None:
                    self.assertEqual(lazy_module.num_batches_tracked.shape, module.num_batches_tracked.shape)
                    self.assertEqual(lazy_module.num_batches_tracked, module.num_batches_tracked)

    def _check_lazy_norm_pickle(self, cls, lazy_cls, input_shape):
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                module = lazy_cls(affine=affine, track_running_stats=track_running_stats)
                module = pickle.loads(pickle.dumps(module))
                self.assertIsInstance(module, lazy_cls)
                if affine:
                    self.assertIsInstance(module.weight, UninitializedParameter)
                    self.assertIsInstance(module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertIsInstance(module.running_mean, UninitializedBuffer)
                    self.assertIsInstance(module.running_var, UninitializedBuffer)
                input = torch.ones(*input_shape)
                module(input)  # fully materialized
                module = pickle.loads(pickle.dumps(module))
                self.assertNotIsInstance(module, lazy_cls)
                self.assertIsInstance(module, cls)
                if affine:
                    self.assertNotIsInstance(module.weight, UninitializedParameter)
                    self.assertNotIsInstance(module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertNotIsInstance(module.running_mean, UninitializedBuffer)
                    self.assertNotIsInstance(module.running_var, UninitializedBuffer)

    def _check_lazy_batchnorm_state(self, cls, lazy_cls):
        module = cls(10)
        lazy_module = lazy_cls(affine=True, track_running_stats=True)
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized but the module won't become a full
        # BatchNorm one until the first iteration. This is due to
        # limitations on the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertEqual(lazy_module.weight.shape, (10,))
        self.assertEqual(lazy_module.bias.shape, (10,))
        self.assertEqual(lazy_module.running_mean.shape, (10,))
        self.assertEqual(lazy_module.running_var.shape, (10,))
        module = cls(10)
        lazy_module = lazy_cls()
        with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
            module.load_state_dict(lazy_module.state_dict())

    def _check_lazy_instancenorm_state(self, cls, lazy_cls):
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                module = cls(10, affine=affine, track_running_stats=track_running_stats)
                lazy_module = lazy_cls(affine=affine, track_running_stats=track_running_stats)
                lazy_module.load_state_dict(module.state_dict())
                # Parameters have been initialized but the module won't become a full
                # InstanceNorm one until the first iteration. This is due to
                # limitations on the state_dict loading logic
                self.assertFalse(lazy_module.has_uninitialized_params())
                if affine:
                    self.assertEqual(lazy_module.weight.shape, (10,))
                    self.assertEqual(lazy_module.bias.shape, (10,))
                if track_running_stats:
                    self.assertEqual(lazy_module.running_mean.shape, (10,))
                    self.assertEqual(lazy_module.running_var.shape, (10,))

        module = cls(10, affine=True, track_running_stats=True)
        lazy_module = lazy_cls(affine=True, track_running_stats=True)
        with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
            module.load_state_dict(lazy_module.state_dict())

    def _check_lazy_norm_with_dict_input(self, cls, lazy_cls, input_shape):
        input = {"input": torch.ones(*input_shape)}
        lazy_module = lazy_cls()
        lazy_output = lazy_module(**input)
        num_features = input_shape[1]
        module = cls(num_features)
        expected_output = module(**input)
        self.assertEqual(lazy_output, expected_output)
    def test_lazy_batchnorm1d(self):
        self._check_lazy_norm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
        self._check_lazy_norm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))

    def test_lazy_batchnorm1d_pickle(self):
        self._check_lazy_norm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
        self._check_lazy_norm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))

    def test_lazy_batchnorm1d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)
        self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)

    def test_lazy_batchnorm2d(self):
        self._check_lazy_norm(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))

    def test_lazy_batchnorm2d_pickle(self):
        self._check_lazy_norm_pickle(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))

    def test_lazy_batchnorm2d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)
        self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)

    def test_lazy_batchnorm3d(self):
        self._check_lazy_norm(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8))

    def test_lazy_batchnorm3d_pickle(self):
        self._check_lazy_norm_pickle(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8))

    def test_lazy_batchnorm3d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)
        self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)

    def test_lazy_instancenorm1d(self):
        self._check_lazy_norm(nn.InstanceNorm1d, nn.LazyInstanceNorm1d, (16, 3, 6))

    def test_lazy_instancenorm1d_pickle(self):
        self._check_lazy_norm_pickle(nn.InstanceNorm1d, nn.LazyInstanceNorm1d, (16, 3, 6))

    def test_lazy_instancenorm1d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm1d, nn.LazyInstanceNorm1d)
        self._check_lazy_instancenorm_state(nn.InstanceNorm1d, nn.LazyInstanceNorm1d)

    def test_lazy_instancenorm2d(self):
        self._check_lazy_norm(nn.InstanceNorm2d, nn.LazyInstanceNorm2d, (16, 3, 6, 7))

    def test_lazy_instancenorm2d_pickle(self):
        self._check_lazy_norm_pickle(nn.InstanceNorm2d, nn.LazyInstanceNorm2d, (16, 3, 6, 7))

    def test_lazy_instancenorm2d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm2d, nn.LazyInstanceNorm2d)
        self._check_lazy_instancenorm_state(nn.InstanceNorm2d, nn.LazyInstanceNorm2d)

    def test_lazy_instancenorm3d(self):
        self._check_lazy_norm(nn.InstanceNorm3d, nn.LazyInstanceNorm3d, (16, 3, 6, 7, 8))

    def test_lazy_instancenorm3d_pickle(self):
        self._check_lazy_norm_pickle(nn.InstanceNorm3d, nn.LazyInstanceNorm3d, (16, 3, 6, 7, 8))

    def test_lazy_instancenorm3d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm3d, nn.LazyInstanceNorm3d)
        self._check_lazy_instancenorm_state(nn.InstanceNorm3d, nn.LazyInstanceNorm3d)

    def test_lazy_batchnorm_with_dict_input(self):
        self._check_lazy_norm_with_dict_input(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
        self._check_lazy_norm_with_dict_input(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))
        self._check_lazy_norm_with_dict_input(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8))
    @suppress_warnings
    def test_materialize_dtype(self):
        module = LazyModule()
        module.register_parameter('test_param', UninitializedParameter())
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.dtype == torch.get_default_dtype())
        module = LazyModule()
        module.register_parameter('test_param', UninitializedParameter())
        module.half()
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.dtype == torch.float16)

    @unittest.skipIf(not (TEST_CUDA or TEST_PRIVATEUSE1), 'CUDA and PRIVATEUSE1 not available')
    @suppress_warnings
    def test_materialize_device(self):
        module = LazyModule()
        module.register_parameter('test_param', UninitializedParameter())
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.device.type == 'cpu')
        if TEST_CUDA:
            device = 'cuda'
        elif TEST_PRIVATEUSE1:
            device = torch._C._get_privateuse1_backend_name()
        module = LazyModule()
        module.register_parameter('test_param', UninitializedParameter())
        module.to(device)
        module.test_param.materialize(10)
        self.assertTrue(module.test_param.device.type == device)

    @suppress_warnings
    def test_chained_initialization(self):
        class MyNetwork(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear_1 = torch.nn.LazyLinear(15)
                self.linear_2 = torch.nn.LazyLinear(10)

            def forward(self, x):
                y = self.linear_1(x)
                return self.linear_2(y)

        net = MyNetwork()
        net(torch.ones(5, 10))
        self.assertTrue(net.linear_1.weight.shape == (15, 10))
        self.assertTrue(net.linear_1.bias.shape == (15,))
        self.assertTrue(net.linear_2.weight.shape == (10, 15))
        self.assertTrue(net.linear_2.bias.shape == (10,))
    @suppress_warnings
    def test_optimizer_pass(self):
        # Add Adamax and RAdam when #118230 and #117836 are complete
        optimizers = [torch.optim.Adadelta, torch.optim.Adagrad, torch.optim.Adam,
                      torch.optim.AdamW, torch.optim.ASGD, torch.optim.SGD, torch.optim.Rprop,
                      torch.optim.RMSprop, torch.optim.LBFGS]

        def run_step(module, optim):
            self.assertIsInstance(optim.param_groups[0]['params'][0], UninitializedParameter)
            module.test_param.materialize(10)
            self.assertIsInstance(optim.param_groups[0]['params'][0], Parameter)
            self.assertNotIsInstance(optim.param_groups[0]['params'][0], UninitializedParameter)
            for p in module.parameters():
                p.grad = torch.rand_like(p)
            if isinstance(optim, torch.optim.LBFGS):
                optim.step(lambda: 1.0)
            else:
                optim.step()

        for optim_cls in optimizers:
            module = LazyModule()
            module.register_parameter('test_param', UninitializedParameter())
            if optim_cls is torch.optim.SGD:
                optim = optim_cls(module.parameters(), lr=0.0)
            elif optim_cls is torch.optim.Adagrad:
                with self.assertRaisesRegex(ValueError, 'uninitialized parameter'):
                    optim = optim_cls(module.parameters())
                continue
            else:
                optim = optim_cls(module.parameters())
            run_step(module, optim)
    @suppress_warnings
    def test_weight_norm(self):
        m = nn.LazyLinear(7)
        with self.assertRaisesRegex(ValueError, 'have uninitialized parameters.'):
            m = torch.nn.utils.weight_norm(m)

    @suppress_warnings
    def test_spectral_norm(self):
        m = nn.LazyLinear(7)
        with self.assertRaisesRegex(ValueError, 'have uninitialized parameters.'):
            m = torch.nn.utils.spectral_norm(m)

    @suppress_warnings
    def test_invalid_functions(self):
        param = torch.nn.parameter.UninitializedParameter()
        with self.assertRaisesRegex(ValueError, 'uninitialized parameter'):
            torch.empty_like(param)
        with self.assertRaisesRegex(ValueError, 'uninitialized parameter'):
            torch.add(param, param)
        with self.assertRaisesRegex(ValueError, 'uninitialized parameter'):
            param + param


if __name__ == '__main__':
    run_tests()