From aa1fd6b45af9aff7f90ca2693ba3b97452e4724c Mon Sep 17 00:00:00 2001 From: Akifumi Imanishi Date: Fri, 5 Feb 2021 10:23:22 -0800 Subject: [PATCH] Add LazyBatchNormXd (#51548) Summary: This PR implements UninitializedBuffer and LazyBatchnormXd based on https://github.com/pytorch/pytorch/issues/44538. (cc. emcastillo and albanD) Pull Request resolved: https://github.com/pytorch/pytorch/pull/51548 Reviewed By: zhangguanheng66 Differential Revision: D26276903 Pulled By: albanD fbshipit-source-id: 0ac706974178363f8af075e59b41d5989418922f --- docs/source/nn.rst | 3 + test/test_nn.py | 166 +++++++++++++++++++++++++++++++++- torch/jit/_recursive.py | 6 +- torch/nn/__init__.py | 2 +- torch/nn/modules/__init__.py | 4 +- torch/nn/modules/batchnorm.py | 120 +++++++++++++++++++++++- torch/nn/modules/lazy.py | 27 +++--- torch/nn/modules/module.py | 2 +- torch/nn/parameter.py | 86 ++++++++++++------ torch/nn/parameter.pyi | 8 ++ 10 files changed, 375 insertions(+), 49 deletions(-) diff --git a/docs/source/nn.rst b/docs/source/nn.rst index 4e0beaae41f..0c1c5fb21f9 100644 --- a/docs/source/nn.rst +++ b/docs/source/nn.rst @@ -176,6 +176,9 @@ Normalization Layers nn.BatchNorm1d nn.BatchNorm2d nn.BatchNorm3d + nn.LazyBatchNorm1d + nn.LazyBatchNorm2d + nn.LazyBatchNorm3d nn.GroupNorm nn.SyncBatchNorm nn.InstanceNorm1d diff --git a/test/test_nn.py b/test/test_nn.py index 4df9fccacbf..89bcc9b8ad4 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -30,7 +30,7 @@ from torch.nn.utils import clip_grad_norm_, clip_grad_value_ import torch.nn.utils.prune as prune from torch.nn.utils import parameters_to_vector, vector_to_parameters from torch.nn import Parameter -from torch.nn.parameter import UninitializedParameter +from torch.nn.parameter import UninitializedParameter, UninitializedBuffer from torch.nn.parallel._functions import Broadcast from torch.testing import get_all_fp_dtypes from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \ @@ -14276,7 +14276,37 @@ class TestLazyModules(TestCase): self.assertTrue(module.has_uninitialized_params()) @suppress_warnings - def test_lazy_module_jit(self): + def test_lazy_module_buffer(self): + module = LazyModule() + module.register_buffer('test_buffer', UninitializedBuffer()) + self.assertTrue(module.has_uninitialized_params()) + state_dict = module.state_dict() + self.assertIsInstance(state_dict['test_buffer'], UninitializedBuffer) + new_module = LazyModule() + # An error is raised when there is an attempt to replace an existing parameter + # with an uninitialized one + new_module.register_buffer('test_buffer', torch.ones(5, 5)) + with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'): + new_module.load_state_dict(state_dict) + # Uninitialized parameters are overriden when the state dict to be loaded contains a valid one + new_module = LazyModule() + new_module.register_buffer('test_buffer', torch.ones(5, 5)) + module.load_state_dict(new_module.state_dict()) + self.assertEqual(module.test_buffer, torch.ones((5, 5))) + + # Uninitialized parameters are left unchanged + module = LazyModule() + module.register_buffer('test_buffer', UninitializedBuffer()) + self.assertTrue(module.has_uninitialized_params()) + + new_module = LazyModule() + new_module.register_buffer('test_buffer', UninitializedBuffer()) + module.load_state_dict(new_module.state_dict()) + module.load_state_dict(new_module.state_dict()) + self.assertTrue(module.has_uninitialized_params()) + + @suppress_warnings + def 
test_lazy_module_jit_param(self): module = LazyModule() module.register_parameter('test_param', UninitializedParameter()) self.assertTrue(module.has_uninitialized_params()) @@ -14284,13 +14314,29 @@ class TestLazyModules(TestCase): torch.jit.script(module) @suppress_warnings - def test_lazy_share_memory(self): + def test_lazy_module_jit_buffer(self): + module = LazyModule() + module.register_buffer('test_buffer', UninitializedBuffer()) + self.assertTrue(module.has_uninitialized_params()) + with self.assertRaisesRegex(RuntimeError, 'run a forward pass'): + torch.jit.script(module) + + @suppress_warnings + def test_lazy_share_memory_param(self): module = LazyModule() module.register_parameter('test_param', UninitializedParameter()) self.assertTrue(module.has_uninitialized_params()) with self.assertRaisesRegex(RuntimeError, 'share memory on an uninitialized'): module.share_memory() + @suppress_warnings + def test_lazy_share_memory_buffer(self): + module = LazyModule() + module.register_buffer('test_buffer', UninitializedBuffer()) + self.assertTrue(module.has_uninitialized_params()) + with self.assertRaisesRegex(RuntimeError, 'share memory on an uninitialized'): + module.share_memory() + @suppress_warnings def test_linear(self): module = nn.LazyLinear(10) @@ -14464,6 +14510,120 @@ class TestLazyModules(TestCase): lambda: nn.LazyConvTranspose3d(32, 2), (16, 32, 2, 2, 2)) + def _check_lazy_batchnorm(self, cls, lazy_cls, input_shape): + for affine in [False, True]: + for track_running_stats in [False, True]: + lazy_module = lazy_cls(affine=affine, track_running_stats=track_running_stats) + + if affine: + self.assertIsInstance(lazy_module.weight, UninitializedParameter) + self.assertIsInstance(lazy_module.bias, UninitializedParameter) + if track_running_stats: + self.assertIsInstance(lazy_module.running_mean, UninitializedBuffer) + self.assertIsInstance(lazy_module.running_var, UninitializedBuffer) + + input = torch.ones(*input_shape) + y = lazy_module(input) + self.assertIsInstance(lazy_module, cls) + self.assertNotIsInstance(lazy_module, lazy_cls) + + num_features = input_shape[1] + module = cls(num_features, affine=affine, track_running_stats=track_running_stats) + expected = module(input) + + if module.weight is not None: + self.assertEqual(lazy_module.weight.shape, module.weight.shape) + self.assertEqual(lazy_module.weight, module.weight) + if module.bias is not None: + self.assertEqual(lazy_module.bias.shape, module.bias.shape) + self.assertEqual(lazy_module.bias, module.bias) + if module.running_mean is not None: + self.assertEqual(lazy_module.running_mean.shape, module.running_mean.shape) + self.assertEqual(lazy_module.running_mean, module.running_mean) + if module.running_var is not None: + self.assertEqual(lazy_module.running_var.shape, module.running_var.shape) + self.assertEqual(lazy_module.running_var, module.running_var) + if module.num_batches_tracked is not None: + self.assertEqual(lazy_module.num_batches_tracked.shape, module.num_batches_tracked.shape) + self.assertEqual(lazy_module.num_batches_tracked, module.num_batches_tracked) + + def _check_lazy_batchnorm_pickle(self, cls, lazy_cls, input_shape): + for affine in [False, True]: + for track_running_stats in [False, True]: + module = lazy_cls(affine=affine, track_running_stats=track_running_stats) + module = pickle.loads(pickle.dumps(module)) + + self.assertIsInstance(module, lazy_cls) + if affine: + self.assertIsInstance(module.weight, UninitializedParameter) + self.assertIsInstance(module.bias, UninitializedParameter) + if 
track_running_stats: + self.assertIsInstance(module.running_mean, UninitializedBuffer) + self.assertIsInstance(module.running_var, UninitializedBuffer) + + input = torch.ones(*input_shape) + module(input) # fully materialized + module = pickle.loads(pickle.dumps(module)) + + self.assertNotIsInstance(module, lazy_cls) + self.assertIsInstance(module, cls) + if affine: + self.assertNotIsInstance(module.weight, UninitializedParameter) + self.assertNotIsInstance(module.bias, UninitializedParameter) + if track_running_stats: + self.assertNotIsInstance(module.running_mean, UninitializedBuffer) + self.assertNotIsInstance(module.running_var, UninitializedBuffer) + + def _check_lazy_batchnorm_state(self, cls, lazy_cls): + module = cls(10) + lazy_module = lazy_cls(affine=True, track_running_stats=True) + lazy_module.load_state_dict(module.state_dict()) + # Parameters have been initialized but the module won't become a full + # Conv one until the first iteration. This is due to + # limitations on the state_dict loading logic + self.assertFalse(lazy_module.has_uninitialized_params()) + self.assertEqual(lazy_module.weight.shape, (10,)) + self.assertEqual(lazy_module.bias.shape, (10,)) + self.assertEqual(lazy_module.running_mean.shape, (10,)) + self.assertEqual(lazy_module.running_var.shape, (10,)) + + module = cls(10) + lazy_module = lazy_cls() + with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'): + module.load_state_dict(lazy_module.state_dict()) + + def test_lazy_batchnorm1d(self): + self._check_lazy_batchnorm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6)) + self._check_lazy_batchnorm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6)) + + def test_lazy_batchnorm1d_pickle(self): + self._check_lazy_batchnorm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6)) + self._check_lazy_batchnorm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6)) + + def test_lazy_batchnorm1d_state(self): + self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d) + self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d) + + def test_lazy_batchnorm2d(self): + self._check_lazy_batchnorm(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7)) + + def test_lazy_batchnorm2d_pickle(self): + self._check_lazy_batchnorm_pickle(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7)) + + def test_lazy_batchnorm2d_state(self): + self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d) + self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d) + + def test_lazy_batchnorm3d(self): + self._check_lazy_batchnorm(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8)) + + def test_lazy_batchnorm3d_pickle(self): + self._check_lazy_batchnorm_pickle(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8)) + + def test_lazy_batchnorm3d_state(self): + self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d) + self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d) + @suppress_warnings def test_materialize_dtype(self): module = LazyModule() diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index d9c3417281f..3150866f95b 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -597,9 +597,13 @@ def check_module_initialized(mod): # This is to avoid importing torch.distributed.nn if not hasattr(mod, 'remote_parameters'): for name, param in mod._parameters.items(): - if isinstance(param, torch.nn.parameter.UninitializedParameter): + if torch.nn.parameter.is_lazy(param): raise RuntimeError("'{}' has uninitialized parameters {}. 
Did you forget to run a forward pass?" .format(torch.typename(type(mod)), name)) + for name, buf in mod._buffers.items(): + if torch.nn.parameter.is_lazy(buf): + raise RuntimeError("'{}' has uninitialized buffers {}. Did you forget to run a forward pass?" + .format(torch.typename(type(mod)), name)) def infer_methods_to_compile(nn_module): """ diff --git a/torch/nn/__init__.py b/torch/nn/__init__.py index 82d7c4341d5..024a409eb3d 100644 --- a/torch/nn/__init__.py +++ b/torch/nn/__init__.py @@ -1,5 +1,5 @@ from .modules import * -from .parameter import Parameter, UninitializedParameter +from .parameter import Parameter, UninitializedParameter, UninitializedBuffer from .parallel import DataParallel from . import init from . import utils diff --git a/torch/nn/modules/__init__.py b/torch/nn/modules/__init__.py index e6a7543bd9b..476713ca9f1 100644 --- a/torch/nn/modules/__init__.py +++ b/torch/nn/modules/__init__.py @@ -15,7 +15,8 @@ from .container import Container, Sequential, ModuleList, ModuleDict, ParameterL from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \ MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \ AdaptiveMaxPool1d, AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d -from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm +from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm, \ + LazyBatchNorm1d, LazyBatchNorm2d, LazyBatchNorm3d from .instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm from .dropout import Dropout, Dropout2d, Dropout3d, AlphaDropout, FeatureAlphaDropout @@ -58,5 +59,6 @@ __all__ = [ 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer', 'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d', 'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d', + 'LazyBatchNorm1d', 'LazyBatchNorm2d', 'LazyBatchNorm3d', 'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle' ] diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 64417069e2b..d41d776b4ff 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -1,8 +1,9 @@ import torch from torch import Tensor from ._functions import SyncBatchNorm as sync_batch_norm +from .lazy import LazyModuleMixin from .module import Module -from torch.nn.parameter import Parameter +from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer from .. import functional as F from .. 
import init @@ -140,6 +141,36 @@ class _BatchNorm(_NormBase): self.weight, self.bias, bn_training, exponential_average_factor, self.eps) +class _LazyBatchNorm(LazyModuleMixin, _BatchNorm): + + def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): + super(_LazyBatchNorm, self).__init__( + 0, eps, momentum, affine, track_running_stats) + if self.affine: + self.weight = UninitializedParameter() + self.bias = UninitializedParameter() + if self.track_running_stats: + self.running_mean = UninitializedBuffer() + self.running_var = UninitializedBuffer() + + def reset_parameters(self) -> None: + if not self.has_uninitialized_params() and self.num_features != 0: + super().reset_parameters() + + def initialize_parameters(self, input) -> None: # type: ignore + if self.has_uninitialized_params(): + self.num_features = input.shape[1] + if self.affine: + assert isinstance(self.weight, UninitializedParameter) + assert isinstance(self.bias, UninitializedParameter) + self.weight.materialize((self.num_features,)) + self.bias.materialize((self.num_features,)) + if self.track_running_stats: + self.running_mean.materialize((self.num_features,)) + self.running_var.materialize((self.num_features,)) + self.reset_parameters() + + class BatchNorm1d(_BatchNorm): r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D inputs with optional additional channel dimension) as described in the paper @@ -213,6 +244,35 @@ class BatchNorm1d(_BatchNorm): .format(input.dim())) +class LazyBatchNorm1d(_LazyBatchNorm): + r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization of + the ``num_features`` argument of the :class:`BatchNorm1d`, which is inferred + from ``input.size(1)``. + + Args: + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics + in both training and eval modes. Default: ``True`` + """ + + cls_to_become = BatchNorm1d # type: ignore[assignment] + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + + class BatchNorm2d(_BatchNorm): r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension) as described in the paper @@ -286,6 +346,35 @@ class BatchNorm2d(_BatchNorm): .format(input.dim())) +class LazyBatchNorm2d(_LazyBatchNorm): + r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization of + the ``num_features`` argument of the :class:`BatchNorm2d`, which is inferred + from ``input.size(1)``. + + Args: + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). 
Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics + in both training and eval modes. Default: ``True`` + """ + + cls_to_become = BatchNorm2d # type: ignore[assignment] + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + + class BatchNorm3d(_BatchNorm): r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper @@ -360,6 +449,35 @@ class BatchNorm3d(_BatchNorm): .format(input.dim())) +class LazyBatchNorm3d(_LazyBatchNorm): + r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization of + the ``num_features`` argument of the :class:`BatchNorm3d`, which is inferred + from ``input.size(1)``. + + Args: + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Can be set to ``None`` for cumulative moving average + (i.e. simple average). Default: 0.1 + affine: a boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + track_running_stats: a boolean value that when set to ``True``, this + module tracks the running mean and variance, and when set to ``False``, + this module does not track such statistics, and initializes statistics + buffers :attr:`running_mean` and :attr:`running_var` as ``None``. + When these buffers are ``None``, this module always uses batch statistics + in both training and eval modes. 
Default: ``True`` + """ + + cls_to_become = BatchNorm3d # type: ignore[assignment] + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + + class SyncBatchNorm(_BatchNorm): r"""Applies Batch Normalization over a N-Dimensional input (a mini-batch of [N-2]D inputs with additional channel dimension) as described in the paper diff --git a/torch/nn/modules/lazy.py b/torch/nn/modules/lazy.py index dbb4eb6c36e..eb2e14bfe7a 100644 --- a/torch/nn/modules/lazy.py +++ b/torch/nn/modules/lazy.py @@ -3,7 +3,7 @@ from typing_extensions import Protocol import warnings import torch -from ..parameter import UninitializedParameter +from ..parameter import is_lazy class _LazyProtocol(Protocol): @@ -181,13 +181,14 @@ class LazyModuleMixin: # which is not clean for name, param in self._parameters.items(): if param is not None: - if isinstance(param, UninitializedParameter): - destination[prefix + name] = param - else: - destination[prefix + name] = param if keep_vars else param.detach() + if not (is_lazy(param) or keep_vars): + param = param.detach() + destination[prefix + name] = param for name, buf in self._buffers.items(): if buf is not None and name not in self._non_persistent_buffers_set: - destination[prefix + name] = buf if keep_vars else buf.detach() + if not (is_lazy(buf) or keep_vars): + buf = buf.detach() + destination[prefix + name] = buf def _lazy_load_hook( self: _LazyProtocol, state_dict, prefix, local_metadata, strict, @@ -201,15 +202,14 @@ class LazyModuleMixin: See comment in ``torch.nn.Module._register_load_state_dict_pre_hook`` for the details of the hook specification. """ - local_state = {k: v for k, v in self._parameters.items() if v is not None} - for name, param in local_state.items(): + for name, param in itertools.chain(self._parameters.items(), self._buffers.items()): key = prefix + name - if key in state_dict: + if key in state_dict and param is not None: input_param = state_dict[key] - if isinstance(param, UninitializedParameter): + if is_lazy(param): # The current parameter is not initialized but the one being loaded one is # create a new parameter based on the uninitialized one - if not isinstance(input_param, UninitializedParameter): + if not is_lazy(input_param): with torch.no_grad(): param.materialize(input_param.shape) @@ -226,8 +226,9 @@ class LazyModuleMixin: # This is to avoid the JIT to track this parameter and force # custom modules __setstate__ to add it params = self._parameters.values() - for param in itertools.chain(params): - if isinstance(param, (UninitializedParameter)): + buffers = self._buffers.values() + for param in itertools.chain(params, buffers): + if is_lazy(param): return True return False diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index f99a588a2e3..d2a4dac73f0 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -1137,7 +1137,7 @@ class Module: # This is used to avoid copying uninitialized parameters into # non-lazy modules, since they dont have the hook to do the checks # in such case, it will error when accessing the .shape attribute. 
- is_param_lazy = isinstance(param, torch.nn.parameter.UninitializedParameter) + is_param_lazy = torch.nn.parameter.is_lazy(param) # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: input_param = input_param[0] diff --git a/torch/nn/parameter.py b/torch/nn/parameter.py index c5d63708e90..ed53f468465 100644 --- a/torch/nn/parameter.py +++ b/torch/nn/parameter.py @@ -46,18 +46,7 @@ class Parameter(torch.Tensor): __torch_function__ = _disabled_torch_function_impl -class UninitializedParameter(Parameter): - r"""A parameter that is not initialized. - - Unitialized Parameters are a a special case of :class:`torch.nn.Parameter` - where the shape of the data is still unknown. - - Unlikely a :class:`torch.nn.Parameter`, uninitialized parameters - hold no data and attempting to access some properties, like their shape, - will throw a runtime error. The only operations that can be performed on a uninitialized - parameter are changing its datatype, moving it to a different device and - converting it to a regular :class:`torch.nn.Parameter`. - """ +class UninitializedTensorMixin: _allowed_methods = [ torch.Tensor.__hash__, torch.Tensor.size, @@ -74,16 +63,13 @@ class UninitializedParameter(Parameter): torch.Tensor.cpu, torch.Tensor.to, torch.Tensor.get_device, - torch._has_compatible_shallow_copy_type] - - def __new__(cls, requires_grad=True): - data = torch.Tensor() - return torch.Tensor._make_subclass(cls, data, requires_grad) + torch._has_compatible_shallow_copy_type, + ] def materialize(self, shape, device=None, dtype=None): - r"""Create a Parameter with the same properties of the uninitialized one. + r"""Create a Parameter or Tensor with the same properties of the uninitialized one. Given a shape, it materializes a parameter in the same device - and with the same `dtype` as the current one or the specified ones in the + and with the same `dtype` as the current one or the specified ones in the arguments. Args: @@ -98,35 +84,35 @@ class UninitializedParameter(Parameter): if dtype is None: dtype = self.data.dtype self.data = torch.empty(shape, device=device, dtype=dtype) - self.__class__ = Parameter + self.__class__ = self.cls_to_become @property def shape(self): raise RuntimeError( - 'Can\'t access the shape of an uninitialized parameter. ' + 'Can\'t access the shape of an uninitialized parameter or buffer. ' 'This error usually happens in `load_state_dict` when trying to load ' 'an uninitialized parameter into an initialized one. ' 'Call `forward` to initialize the parameters before accessing their attributes.') def share_memory_(self): raise RuntimeError( - 'Can\'t share memory on an uninitialized parameter. ' + 'Can\'t share memory on an uninitialized parameter or buffer. 
' 'Call `forward` to initialize the parameters before calling ' '`module.share_memory()`.') def __repr__(self): - return 'Uninitialized parameter' + return f'<{self.__class__.__name__}>' def __reduce_ex__(self, proto): # See Note [Don't serialize hooks] return ( - UninitializedParameter, + self.__class__, (self.requires_grad,) ) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): - # method-wrapper is to detect access to Tensor properties that are + # method-wrapper is to detect access to Tensor properties that are # wrapped in descriptors if func in cls._allowed_methods or func.__class__.__name__ == 'method-wrapper': if kwargs is None: @@ -135,6 +121,50 @@ class UninitializedParameter(Parameter): raise ValueError( 'Attempted to use an uninitialized parameter in {}. ' 'This error happens when you are using a `LazyModule` or ' - 'explicitly manipulating `torch.nn.parameter.UninitializedParameter` ' - 'objects. When using LazyModules Call `forward` with a dummy batch ' - 'to initialize the parameters before calling torch functions'.format(func)) + 'explicitly manipulating `torch.nn.parameter.{}` ' + 'objects. When using LazyModules, call `forward` with a dummy batch ' + 'to initialize the parameters before calling torch functions'.format(func, cls.__name__)) + + +def is_lazy(param): + return isinstance(param, UninitializedTensorMixin) + + +class UninitializedParameter(UninitializedTensorMixin, Parameter): + r"""A parameter that is not initialized. + + Uninitialized parameters are a special case of :class:`torch.nn.Parameter` + where the shape of the data is still unknown. + + Unlike a :class:`torch.nn.Parameter`, uninitialized parameters + hold no data and attempting to access some properties, like their shape, + will throw a runtime error. The only operations that can be performed on an uninitialized + parameter are changing its datatype, moving it to a different device and + converting it to a regular :class:`torch.nn.Parameter`. + """ + + cls_to_become = Parameter + + def __new__(cls, requires_grad=True): + data = torch.Tensor() + return torch.Tensor._make_subclass(cls, data, requires_grad) + + +class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor): + r"""A buffer that is not initialized. + + Uninitialized buffers are a special case of :class:`torch.Tensor` + where the shape of the data is still unknown. + + Unlike a :class:`torch.Tensor`, uninitialized buffers + hold no data and attempting to access some properties, like their shape, + will throw a runtime error. The only operations that can be performed on an uninitialized + buffer are changing its datatype, moving it to a different device and + converting it to a regular :class:`torch.Tensor`. + """ + + cls_to_become = torch.Tensor + + def __new__(cls, requires_grad=False): + data = torch.Tensor() + return torch.Tensor._make_subclass(cls, data, requires_grad) diff --git a/torch/nn/parameter.pyi b/torch/nn/parameter.pyi index 7e9e17eebf8..ff15a375d50 100644 --- a/torch/nn/parameter.pyi +++ b/torch/nn/parameter.pyi @@ -8,8 +8,16 @@ class Parameter(Tensor): ... +def is_lazy(param: Tensor): ... + class UninitializedParameter(Tensor): def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ... def materialize(self, shape: Tuple[int, ...], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ... ... + +class UninitializedBuffer(Tensor): + def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ... 
+ + def materialize(self, shape: Tuple[int, ...], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ... + ...
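
Usage sketch (illustrative only, not part of the patch): a minimal example of the behaviour exercised by the tests above, assuming this patch is applied on top of a PyTorch build that already provides the LazyModuleMixin/cls_to_become machinery from the earlier lazy-module PRs. All names used here (nn.LazyBatchNorm2d, UninitializedBuffer, is_lazy, materialize) come from this patch; the shapes are arbitrary.

import torch
import torch.nn as nn
from torch.nn.parameter import UninitializedBuffer, is_lazy

# LazyBatchNorm2d defers num_features until the first forward pass.
bn = nn.LazyBatchNorm2d()                       # affine=True, track_running_stats=True by default
assert is_lazy(bn.weight) and is_lazy(bn.running_mean)   # uninitialized parameter / buffer

x = torch.randn(4, 3, 8, 8)                     # num_features is inferred from input.size(1) == 3
bn(x)                                           # first forward materializes parameters and buffers
assert isinstance(bn, nn.BatchNorm2d)           # the module becomes its cls_to_become
assert bn.weight.shape == (3,) and bn.running_mean.shape == (3,)

# UninitializedBuffer can also be registered directly on a module and
# materialized later, mirroring UninitializedParameter.
m = nn.Module()
m.register_buffer('stats', UninitializedBuffer())
m.stats.materialize((3,))                       # becomes a regular torch.Tensor of shape (3,)
assert not is_lazy(m.stats)

As the new state_dict tests show, an uninitialized parameter or buffer round-trips through state_dict as-is and is only materialized when a shaped tensor is loaded into it; loading an uninitialized entry into an already-initialized module raises the "shape of an uninitialized" error.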