Add LazyBatchNormXd (#51548)

Summary:
This PR implements UninitializedBuffer and LazyBatchNormXd (LazyBatchNorm1d/2d/3d) based on https://github.com/pytorch/pytorch/issues/44538. (cc emcastillo and albanD)
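As a quick illustration (not part of the diff; a minimal sketch assuming the PR as merged), the new lazy modules defer choosing num_features until the first forward pass, at which point their uninitialized parameters and buffers are materialized and the module converts to its regular counterpart:

import torch
import torch.nn as nn

bn = nn.LazyBatchNorm2d()          # num_features is not specified
x = torch.randn(8, 3, 32, 32)      # 8 samples, 3 channels
y = bn(x)                          # first forward pass materializes weight, bias,
                                   # running_mean and running_var with 3 features
print(type(bn).__name__)           # 'BatchNorm2d' -- the module became non-lazy
print(bn.num_features)             # 3
print(bn.running_mean.shape)       # torch.Size([3])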

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51548

Reviewed By: zhangguanheng66

Differential Revision: D26276903

Pulled By: albanD

fbshipit-source-id: 0ac706974178363f8af075e59b41d5989418922f
Author: Akifumi Imanishi
Date: 2021-02-05 10:23:22 -08:00
Committed by: Facebook GitHub Bot
Parent: 5a962369e2
Commit: aa1fd6b45a
10 changed files with 375 additions and 49 deletions

View File

@@ -176,6 +176,9 @@ Normalization Layers
     nn.BatchNorm1d
     nn.BatchNorm2d
     nn.BatchNorm3d
+    nn.LazyBatchNorm1d
+    nn.LazyBatchNorm2d
+    nn.LazyBatchNorm3d
     nn.GroupNorm
     nn.SyncBatchNorm
     nn.InstanceNorm1d

View File

@@ -30,7 +30,7 @@ from torch.nn.utils import clip_grad_norm_, clip_grad_value_
 import torch.nn.utils.prune as prune
 from torch.nn.utils import parameters_to_vector, vector_to_parameters
 from torch.nn import Parameter
-from torch.nn.parameter import UninitializedParameter
+from torch.nn.parameter import UninitializedParameter, UninitializedBuffer
 from torch.nn.parallel._functions import Broadcast
 from torch.testing import get_all_fp_dtypes
 from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
@@ -14276,7 +14276,37 @@ class TestLazyModules(TestCase):
         self.assertTrue(module.has_uninitialized_params())

     @suppress_warnings
-    def test_lazy_module_jit(self):
+    def test_lazy_module_buffer(self):
+        module = LazyModule()
+        module.register_buffer('test_buffer', UninitializedBuffer())
+        self.assertTrue(module.has_uninitialized_params())
+        state_dict = module.state_dict()
+        self.assertIsInstance(state_dict['test_buffer'], UninitializedBuffer)
+        new_module = LazyModule()
+        # An error is raised when there is an attempt to replace an existing parameter
+        # with an uninitialized one
+        new_module.register_buffer('test_buffer', torch.ones(5, 5))
+        with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
+            new_module.load_state_dict(state_dict)
+        # Uninitialized parameters are overridden when the state dict to be loaded contains a valid one
+        new_module = LazyModule()
+        new_module.register_buffer('test_buffer', torch.ones(5, 5))
+        module.load_state_dict(new_module.state_dict())
+        self.assertEqual(module.test_buffer, torch.ones((5, 5)))
+
+        # Uninitialized parameters are left unchanged
+        module = LazyModule()
+        module.register_buffer('test_buffer', UninitializedBuffer())
+        self.assertTrue(module.has_uninitialized_params())
+
+        new_module = LazyModule()
+        new_module.register_buffer('test_buffer', UninitializedBuffer())
+        module.load_state_dict(new_module.state_dict())
+        module.load_state_dict(new_module.state_dict())
+        self.assertTrue(module.has_uninitialized_params())
+
+    @suppress_warnings
+    def test_lazy_module_jit_param(self):
         module = LazyModule()
         module.register_parameter('test_param', UninitializedParameter())
         self.assertTrue(module.has_uninitialized_params())
@@ -14284,13 +14314,29 @@ class TestLazyModules(TestCase):
             torch.jit.script(module)

     @suppress_warnings
-    def test_lazy_share_memory(self):
+    def test_lazy_module_jit_buffer(self):
+        module = LazyModule()
+        module.register_buffer('test_buffer', UninitializedBuffer())
+        self.assertTrue(module.has_uninitialized_params())
+        with self.assertRaisesRegex(RuntimeError, 'run a forward pass'):
+            torch.jit.script(module)
+
+    @suppress_warnings
+    def test_lazy_share_memory_param(self):
         module = LazyModule()
         module.register_parameter('test_param', UninitializedParameter())
         self.assertTrue(module.has_uninitialized_params())
         with self.assertRaisesRegex(RuntimeError, 'share memory on an uninitialized'):
             module.share_memory()

+    @suppress_warnings
+    def test_lazy_share_memory_buffer(self):
+        module = LazyModule()
+        module.register_buffer('test_buffer', UninitializedBuffer())
+        self.assertTrue(module.has_uninitialized_params())
+        with self.assertRaisesRegex(RuntimeError, 'share memory on an uninitialized'):
+            module.share_memory()
+
     @suppress_warnings
     def test_linear(self):
         module = nn.LazyLinear(10)
@@ -14464,6 +14510,120 @@ class TestLazyModules(TestCase):
             lambda: nn.LazyConvTranspose3d(32, 2),
             (16, 32, 2, 2, 2))

+    def _check_lazy_batchnorm(self, cls, lazy_cls, input_shape):
+        for affine in [False, True]:
+            for track_running_stats in [False, True]:
+                lazy_module = lazy_cls(affine=affine, track_running_stats=track_running_stats)
+
+                if affine:
+                    self.assertIsInstance(lazy_module.weight, UninitializedParameter)
+                    self.assertIsInstance(lazy_module.bias, UninitializedParameter)
+                if track_running_stats:
+                    self.assertIsInstance(lazy_module.running_mean, UninitializedBuffer)
+                    self.assertIsInstance(lazy_module.running_var, UninitializedBuffer)
+
+                input = torch.ones(*input_shape)
+                y = lazy_module(input)
+                self.assertIsInstance(lazy_module, cls)
+                self.assertNotIsInstance(lazy_module, lazy_cls)
+
+                num_features = input_shape[1]
+                module = cls(num_features, affine=affine, track_running_stats=track_running_stats)
+                expected = module(input)
+
+                if module.weight is not None:
+                    self.assertEqual(lazy_module.weight.shape, module.weight.shape)
+                    self.assertEqual(lazy_module.weight, module.weight)
+                if module.bias is not None:
+                    self.assertEqual(lazy_module.bias.shape, module.bias.shape)
+                    self.assertEqual(lazy_module.bias, module.bias)
+                if module.running_mean is not None:
+                    self.assertEqual(lazy_module.running_mean.shape, module.running_mean.shape)
+                    self.assertEqual(lazy_module.running_mean, module.running_mean)
+                if module.running_var is not None:
+                    self.assertEqual(lazy_module.running_var.shape, module.running_var.shape)
+                    self.assertEqual(lazy_module.running_var, module.running_var)
+                if module.num_batches_tracked is not None:
+                    self.assertEqual(lazy_module.num_batches_tracked.shape, module.num_batches_tracked.shape)
+                    self.assertEqual(lazy_module.num_batches_tracked, module.num_batches_tracked)
+
+    def _check_lazy_batchnorm_pickle(self, cls, lazy_cls, input_shape):
+        for affine in [False, True]:
+            for track_running_stats in [False, True]:
+                module = lazy_cls(affine=affine, track_running_stats=track_running_stats)
+                module = pickle.loads(pickle.dumps(module))
+                self.assertIsInstance(module, lazy_cls)
+                if affine:
+                    self.assertIsInstance(module.weight, UninitializedParameter)
+                    self.assertIsInstance(module.bias, UninitializedParameter)
+                if track_running_stats:
+                    self.assertIsInstance(module.running_mean, UninitializedBuffer)
+                    self.assertIsInstance(module.running_var, UninitializedBuffer)
+
+                input = torch.ones(*input_shape)
+                module(input)  # fully materialized
+                module = pickle.loads(pickle.dumps(module))
+                self.assertNotIsInstance(module, lazy_cls)
+                self.assertIsInstance(module, cls)
+                if affine:
+                    self.assertNotIsInstance(module.weight, UninitializedParameter)
+                    self.assertNotIsInstance(module.bias, UninitializedParameter)
+                if track_running_stats:
+                    self.assertNotIsInstance(module.running_mean, UninitializedBuffer)
+                    self.assertNotIsInstance(module.running_var, UninitializedBuffer)
+
+    def _check_lazy_batchnorm_state(self, cls, lazy_cls):
+        module = cls(10)
+        lazy_module = lazy_cls(affine=True, track_running_stats=True)
+        lazy_module.load_state_dict(module.state_dict())
+        # Parameters have been initialized but the module won't become a full
+        # BatchNorm one until the first iteration. This is due to
+        # limitations on the state_dict loading logic
+        self.assertFalse(lazy_module.has_uninitialized_params())
+        self.assertEqual(lazy_module.weight.shape, (10,))
+        self.assertEqual(lazy_module.bias.shape, (10,))
+        self.assertEqual(lazy_module.running_mean.shape, (10,))
+        self.assertEqual(lazy_module.running_var.shape, (10,))
+
+        module = cls(10)
+        lazy_module = lazy_cls()
+        with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
+            module.load_state_dict(lazy_module.state_dict())
+
+    def test_lazy_batchnorm1d(self):
+        self._check_lazy_batchnorm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
+        self._check_lazy_batchnorm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))
+
+    def test_lazy_batchnorm1d_pickle(self):
+        self._check_lazy_batchnorm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
+        self._check_lazy_batchnorm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))
+
+    def test_lazy_batchnorm1d_state(self):
+        self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)
+        self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)
+
+    def test_lazy_batchnorm2d(self):
+        self._check_lazy_batchnorm(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))
+
+    def test_lazy_batchnorm2d_pickle(self):
+        self._check_lazy_batchnorm_pickle(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))
+
+    def test_lazy_batchnorm2d_state(self):
+        self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)
+        self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)
+
+    def test_lazy_batchnorm3d(self):
+        self._check_lazy_batchnorm(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8))
+
+    def test_lazy_batchnorm3d_pickle(self):
+        self._check_lazy_batchnorm_pickle(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8))
+
+    def test_lazy_batchnorm3d_state(self):
+        self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)
+        self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)
+
     @suppress_warnings
     def test_materialize_dtype(self):
         module = LazyModule()

View File

@@ -597,9 +597,13 @@ def check_module_initialized(mod):
     # This is to avoid importing torch.distributed.nn
     if not hasattr(mod, 'remote_parameters'):
         for name, param in mod._parameters.items():
-            if isinstance(param, torch.nn.parameter.UninitializedParameter):
+            if torch.nn.parameter.is_lazy(param):
                 raise RuntimeError("'{}' has uninitialized parameters {}. Did you forget to run a forward pass?"
                                    .format(torch.typename(type(mod)), name))
+        for name, buf in mod._buffers.items():
+            if torch.nn.parameter.is_lazy(buf):
+                raise RuntimeError("'{}' has uninitialized buffers {}. Did you forget to run a forward pass?"
+                                   .format(torch.typename(type(mod)), name))


 def infer_methods_to_compile(nn_module):
     """

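To illustrate what the extended check above guards against (a minimal sketch, assuming the behaviour added in this PR), scripting a lazy module before any forward pass raises, and succeeds once its parameters and buffers have been materialized:

import torch
import torch.nn as nn

m = nn.LazyBatchNorm1d()
try:
    torch.jit.script(m)             # should raise: "... Did you forget to run a forward pass?"
except RuntimeError as e:
    print(e)

m(torch.randn(4, 10))               # materializes num_features=10; m becomes nn.BatchNorm1d
scripted = torch.jit.script(m)      # now succeeds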
View File

@@ -1,5 +1,5 @@
 from .modules import *
-from .parameter import Parameter, UninitializedParameter
+from .parameter import Parameter, UninitializedParameter, UninitializedBuffer
 from .parallel import DataParallel
 from . import init
 from . import utils

View File

@@ -15,7 +15,8 @@ from .container import Container, Sequential, ModuleList, ModuleDict, ParameterL
 from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
     MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \
     AdaptiveMaxPool1d, AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d
-from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm
+from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm, \
+    LazyBatchNorm1d, LazyBatchNorm2d, LazyBatchNorm3d
 from .instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d
 from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm
 from .dropout import Dropout, Dropout2d, Dropout3d, AlphaDropout, FeatureAlphaDropout
@@ -58,5 +59,6 @@ __all__ = [
     'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
     'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
     'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
+    'LazyBatchNorm1d', 'LazyBatchNorm2d', 'LazyBatchNorm3d',
     'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle'
 ]

View File

@@ -1,8 +1,9 @@
 import torch
 from torch import Tensor
 from ._functions import SyncBatchNorm as sync_batch_norm
+from .lazy import LazyModuleMixin
 from .module import Module
-from torch.nn.parameter import Parameter
+from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer
 from .. import functional as F
 from .. import init
@@ -140,6 +141,36 @@ class _BatchNorm(_NormBase):
             self.weight, self.bias, bn_training, exponential_average_factor, self.eps)


+class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):
+
+    def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
+        super(_LazyBatchNorm, self).__init__(
+            0, eps, momentum, affine, track_running_stats)
+        if self.affine:
+            self.weight = UninitializedParameter()
+            self.bias = UninitializedParameter()
+        if self.track_running_stats:
+            self.running_mean = UninitializedBuffer()
+            self.running_var = UninitializedBuffer()
+
+    def reset_parameters(self) -> None:
+        if not self.has_uninitialized_params() and self.num_features != 0:
+            super().reset_parameters()
+
+    def initialize_parameters(self, input) -> None:  # type: ignore
+        if self.has_uninitialized_params():
+            self.num_features = input.shape[1]
+            if self.affine:
+                assert isinstance(self.weight, UninitializedParameter)
+                assert isinstance(self.bias, UninitializedParameter)
+                self.weight.materialize((self.num_features,))
+                self.bias.materialize((self.num_features,))
+            if self.track_running_stats:
+                self.running_mean.materialize((self.num_features,))
+                self.running_var.materialize((self.num_features,))
+            self.reset_parameters()
+
+
 class BatchNorm1d(_BatchNorm):
     r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
     inputs with optional additional channel dimension) as described in the paper
@@ -213,6 +244,35 @@ class BatchNorm1d(_BatchNorm):
                              .format(input.dim()))


+class LazyBatchNorm1d(_LazyBatchNorm):
+    r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization of
+    the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred
+    from the ``input.size(1)``.
+
+    Args:
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``True``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics, and initializes statistics
+            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+            When these buffers are ``None``, this module always uses batch statistics
+            in both training and eval modes. Default: ``True``
+    """
+
+    cls_to_become = BatchNorm1d  # type: ignore[assignment]
+
+    def _check_input_dim(self, input):
+        if input.dim() != 2 and input.dim() != 3:
+            raise ValueError('expected 2D or 3D input (got {}D input)'
+                             .format(input.dim()))
+
+
 class BatchNorm2d(_BatchNorm):
     r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
     with additional channel dimension) as described in the paper
@@ -286,6 +346,35 @@ class BatchNorm2d(_BatchNorm):
                              .format(input.dim()))


+class LazyBatchNorm2d(_LazyBatchNorm):
+    r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization of
+    the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred
+    from the ``input.size(1)``.
+
+    Args:
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``True``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics, and initializes statistics
+            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+            When these buffers are ``None``, this module always uses batch statistics
+            in both training and eval modes. Default: ``True``
+    """
+
+    cls_to_become = BatchNorm2d  # type: ignore[assignment]
+
+    def _check_input_dim(self, input):
+        if input.dim() != 4:
+            raise ValueError('expected 4D input (got {}D input)'
+                             .format(input.dim()))
+
+
 class BatchNorm3d(_BatchNorm):
     r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
     with additional channel dimension) as described in the paper
@@ -360,6 +449,35 @@ class BatchNorm3d(_BatchNorm):
                              .format(input.dim()))


+class LazyBatchNorm3d(_LazyBatchNorm):
+    r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization of
+    the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred
+    from the ``input.size(1)``.
+
+    Args:
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``True``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics, and initializes statistics
+            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+            When these buffers are ``None``, this module always uses batch statistics
+            in both training and eval modes. Default: ``True``
+    """
+
+    cls_to_become = BatchNorm3d  # type: ignore[assignment]
+
+    def _check_input_dim(self, input):
+        if input.dim() != 5:
+            raise ValueError('expected 5D input (got {}D input)'
+                             .format(input.dim()))
+
+
 class SyncBatchNorm(_BatchNorm):
     r"""Applies Batch Normalization over a N-Dimensional input (a mini-batch of [N-2]D inputs
     with additional channel dimension) as described in the paper

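A short sketch of how the new modules interact with state_dict loading (mirroring the _check_lazy_batchnorm_state test above; illustrative only, assuming the PR as merged): loading a fully initialized BatchNorm state dict into a LazyBatchNorm materializes its parameters and buffers, while the class conversion still waits for the first forward pass.

import torch
import torch.nn as nn

src = nn.BatchNorm2d(10)
lazy = nn.LazyBatchNorm2d(affine=True, track_running_stats=True)
lazy.load_state_dict(src.state_dict())

print(lazy.has_uninitialized_params())   # False: weight/bias/running stats now have shape (10,)
print(lazy.weight.shape)                 # torch.Size([10])
print(type(lazy).__name__)               # still 'LazyBatchNorm2d' until a forward pass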
View File

@@ -3,7 +3,7 @@ from typing_extensions import Protocol
 import warnings

 import torch
-from ..parameter import UninitializedParameter
+from ..parameter import is_lazy


 class _LazyProtocol(Protocol):
@@ -181,13 +181,14 @@
         # which is not clean
         for name, param in self._parameters.items():
             if param is not None:
-                if isinstance(param, UninitializedParameter):
-                    destination[prefix + name] = param
-                else:
-                    destination[prefix + name] = param if keep_vars else param.detach()
+                if not (is_lazy(param) or keep_vars):
+                    param = param.detach()
+                destination[prefix + name] = param
         for name, buf in self._buffers.items():
             if buf is not None and name not in self._non_persistent_buffers_set:
-                destination[prefix + name] = buf if keep_vars else buf.detach()
+                if not (is_lazy(buf) or keep_vars):
+                    buf = buf.detach()
+                destination[prefix + name] = buf

     def _lazy_load_hook(
             self: _LazyProtocol, state_dict, prefix, local_metadata, strict,
@@ -201,15 +202,14 @@
         See comment in ``torch.nn.Module._register_load_state_dict_pre_hook``
         for the details of the hook specification.
         """
-        local_state = {k: v for k, v in self._parameters.items() if v is not None}
-        for name, param in local_state.items():
+        for name, param in itertools.chain(self._parameters.items(), self._buffers.items()):
             key = prefix + name
-            if key in state_dict:
+            if key in state_dict and param is not None:
                 input_param = state_dict[key]
-                if isinstance(param, UninitializedParameter):
+                if is_lazy(param):
                     # The current parameter is not initialized but the one being loaded one is
                     # create a new parameter based on the uninitialized one
-                    if not isinstance(input_param, UninitializedParameter):
+                    if not is_lazy(input_param):
                         with torch.no_grad():
                             param.materialize(input_param.shape)
@@ -226,8 +226,9 @@
         # This is to avoid the JIT to track this parameter and force
         # custom modules __setstate__ to add it
         params = self._parameters.values()
-        for param in itertools.chain(params):
-            if isinstance(param, (UninitializedParameter)):
+        buffers = self._buffers.values()
+        for param in itertools.chain(params, buffers):
+            if is_lazy(param):
                 return True
         return False

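For orientation, here is a hypothetical custom module built on the extended LazyModuleMixin (a sketch only; LazyRunningMax and its buffer name are illustrative, not part of this PR). It registers an UninitializedBuffer and materializes it from the first input it sees, so has_uninitialized_params() flips after the first forward pass:

import torch
import torch.nn as nn
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import UninitializedBuffer

class LazyRunningMax(LazyModuleMixin, nn.Module):
    cls_to_become = None  # keep the same class after initialization (assumed default behaviour)

    def __init__(self):
        super().__init__()
        # the buffer's size depends on the (yet unknown) feature dimension
        self.register_buffer('running_max', UninitializedBuffer())

    def initialize_parameters(self, input):
        # called by the mixin's pre-forward hook before the first forward pass
        if self.has_uninitialized_params():
            self.running_max.materialize((input.shape[1],))
            self.running_max.fill_(float('-inf'))

    def forward(self, input):
        with torch.no_grad():
            self.running_max = torch.maximum(self.running_max, input.max(dim=0).values)
        return input

m = LazyRunningMax()
print(m.has_uninitialized_params())   # True
m(torch.randn(4, 6))
print(m.has_uninitialized_params())   # False; running_max now has shape (6,)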
View File

@@ -1137,7 +1137,7 @@
                     # This is used to avoid copying uninitialized parameters into
                     # non-lazy modules, since they dont have the hook to do the checks
                     # in such case, it will error when accessing the .shape attribute.
-                    is_param_lazy = isinstance(param, torch.nn.parameter.UninitializedParameter)
+                    is_param_lazy = torch.nn.parameter.is_lazy(param)
                     # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                     if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
                         input_param = input_param[0]

View File

@@ -46,18 +46,7 @@ class Parameter(torch.Tensor):
     __torch_function__ = _disabled_torch_function_impl


-class UninitializedParameter(Parameter):
-    r"""A parameter that is not initialized.
-
-    Unitialized Parameters are a a special case of :class:`torch.nn.Parameter`
-    where the shape of the data is still unknown.
-
-    Unlikely a :class:`torch.nn.Parameter`, uninitialized parameters
-    hold no data and attempting to access some properties, like their shape,
-    will throw a runtime error. The only operations that can be performed on a uninitialized
-    parameter are changing its datatype, moving it to a different device and
-    converting it to a regular :class:`torch.nn.Parameter`.
-    """
+class UninitializedTensorMixin:
     _allowed_methods = [
         torch.Tensor.__hash__,
         torch.Tensor.size,
@@ -74,16 +63,13 @@
         torch.Tensor.cpu,
         torch.Tensor.to,
         torch.Tensor.get_device,
-        torch._has_compatible_shallow_copy_type]
-
-    def __new__(cls, requires_grad=True):
-        data = torch.Tensor()
-        return torch.Tensor._make_subclass(cls, data, requires_grad)
+        torch._has_compatible_shallow_copy_type,
+    ]

     def materialize(self, shape, device=None, dtype=None):
-        r"""Create a Parameter with the same properties of the uninitialized one.
+        r"""Create a Parameter or Tensor with the same properties of the uninitialized one.
         Given a shape, it materializes a parameter in the same device
         and with the same `dtype` as the current one or the specified ones in the
         arguments.

         Args:
@@ -98,35 +84,35 @@
         if dtype is None:
             dtype = self.data.dtype
         self.data = torch.empty(shape, device=device, dtype=dtype)
-        self.__class__ = Parameter
+        self.__class__ = self.cls_to_become

     @property
     def shape(self):
         raise RuntimeError(
-            'Can\'t access the shape of an uninitialized parameter. '
+            'Can\'t access the shape of an uninitialized parameter or buffer. '
             'This error usually happens in `load_state_dict` when trying to load '
             'an uninitialized parameter into an initialized one. '
             'Call `forward` to initialize the parameters before accessing their attributes.')

     def share_memory_(self):
         raise RuntimeError(
-            'Can\'t share memory on an uninitialized parameter. '
+            'Can\'t share memory on an uninitialized parameter or buffer. '
            'Call `forward` to initialize the parameters before calling '
            '`module.share_memory()`.')

     def __repr__(self):
-        return 'Uninitialized parameter'
+        return f'<{self.__class__.__name__}>'

     def __reduce_ex__(self, proto):
         # See Note [Don't serialize hooks]
         return (
-            UninitializedParameter,
+            self.__class__,
             (self.requires_grad,)
         )

     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         # method-wrapper is to detect access to Tensor properties that are
         # wrapped in descriptors
         if func in cls._allowed_methods or func.__class__.__name__ == 'method-wrapper':
             if kwargs is None:
@@ -135,6 +121,50 @@ class UninitializedParameter(Parameter):
             raise ValueError(
                 'Attempted to use an uninitialized parameter in {}. '
                 'This error happens when you are using a `LazyModule` or '
-                'explicitly manipulating `torch.nn.parameter.UninitializedParameter` '
+                'explicitly manipulating `torch.nn.parameter.{}` '
                 'objects. When using LazyModules Call `forward` with a dummy batch '
-                'to initialize the parameters before calling torch functions'.format(func))
+                'to initialize the parameters before calling torch functions'.format(func, cls.__name__))
+
+
+def is_lazy(param):
+    return isinstance(param, UninitializedTensorMixin)
+
+
+class UninitializedParameter(UninitializedTensorMixin, Parameter):
+    r"""A parameter that is not initialized.
+
+    Uninitialized Parameters are a special case of :class:`torch.nn.Parameter`
+    where the shape of the data is still unknown.
+
+    Unlike a :class:`torch.nn.Parameter`, uninitialized parameters
+    hold no data and attempting to access some properties, like their shape,
+    will throw a runtime error. The only operations that can be performed on an uninitialized
+    parameter are changing its datatype, moving it to a different device and
+    converting it to a regular :class:`torch.nn.Parameter`.
+    """
+
+    cls_to_become = Parameter
+
+    def __new__(cls, requires_grad=True):
+        data = torch.Tensor()
+        return torch.Tensor._make_subclass(cls, data, requires_grad)
+
+
+class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
+    r"""A buffer that is not initialized.
+
+    Uninitialized Buffers are a special case of :class:`torch.Tensor`
+    where the shape of the data is still unknown.
+
+    Unlike a :class:`torch.Tensor`, uninitialized buffers
+    hold no data and attempting to access some properties, like their shape,
+    will throw a runtime error. The only operations that can be performed on an uninitialized
+    buffer are changing its datatype, moving it to a different device and
+    converting it to a regular :class:`torch.Tensor`.
+    """
+
+    cls_to_become = torch.Tensor
+
+    def __new__(cls, requires_grad=False):
+        data = torch.Tensor()
+        return torch.Tensor._make_subclass(cls, data, requires_grad)

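A brief sketch of the new helpers in isolation (behaviour as added above; illustrative only):

import torch
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter, is_lazy

buf = UninitializedBuffer()
print(is_lazy(buf))        # True; accessing buf.shape at this point raises RuntimeError

buf.materialize((4,))      # allocates data and rebinds __class__ to torch.Tensor
print(is_lazy(buf))        # False
print(buf.shape)           # torch.Size([4])

p = UninitializedParameter()
p.materialize((2, 3))      # becomes a regular torch.nn.Parameter of shape (2, 3)
print(type(p).__name__)    # 'Parameter'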
View File

@@ -8,8 +8,16 @@ class Parameter(Tensor):
     ...


+def is_lazy(param: Tensor): ...
+
+
 class UninitializedParameter(Tensor):
     def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ...
     def materialize(self, shape: Tuple[int, ...], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ...
     ...
+
+
+class UninitializedBuffer(Tensor):
+    def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ...
+    def materialize(self, shape: Tuple[int, ...], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ...
+    ...