Add LazyBatchNormXd (#51548)

Summary:
This PR implements UninitializedBuffer and LazyBatchNormXd (LazyBatchNorm1d/2d/3d) based on https://github.com/pytorch/pytorch/issues/44538. (cc. emcastillo and albanD)
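For illustration, a minimal usage sketch (not part of the changed files): the lazy variants are constructed without num_features, which is inferred from input.size(1) on the first forward pass, after which the module turns into its regular counterpart.

import torch
import torch.nn as nn

bn = nn.LazyBatchNorm2d()            # no num_features argument
x = torch.randn(4, 3, 8, 8)          # (N, C, H, W), so C == 3 is inferred
y = bn(x)
# After the first forward the module has materialized weight, bias and the
# running statistics, and converted itself into a regular BatchNorm2d.
assert isinstance(bn, nn.BatchNorm2d)
assert bn.weight.shape == (3,) and bn.running_mean.shape == (3,)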

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51548

Reviewed By: zhangguanheng66

Differential Revision: D26276903

Pulled By: albanD

fbshipit-source-id: 0ac706974178363f8af075e59b41d5989418922f
Akifumi Imanishi 2021-02-05 10:23:22 -08:00 committed by Facebook GitHub Bot
parent 5a962369e2
commit aa1fd6b45a
10 changed files with 375 additions and 49 deletions

@@ -176,6 +176,9 @@ Normalization Layers
nn.BatchNorm1d
nn.BatchNorm2d
nn.BatchNorm3d
nn.LazyBatchNorm1d
nn.LazyBatchNorm2d
nn.LazyBatchNorm3d
nn.GroupNorm
nn.SyncBatchNorm
nn.InstanceNorm1d

@@ -30,7 +30,7 @@ from torch.nn.utils import clip_grad_norm_, clip_grad_value_
import torch.nn.utils.prune as prune
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.nn import Parameter
from torch.nn.parameter import UninitializedParameter
from torch.nn.parameter import UninitializedParameter, UninitializedBuffer
from torch.nn.parallel._functions import Broadcast
from torch.testing import get_all_fp_dtypes
from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
@@ -14276,7 +14276,37 @@ class TestLazyModules(TestCase):
self.assertTrue(module.has_uninitialized_params())
@suppress_warnings
def test_lazy_module_jit(self):
def test_lazy_module_buffer(self):
module = LazyModule()
module.register_buffer('test_buffer', UninitializedBuffer())
self.assertTrue(module.has_uninitialized_params())
state_dict = module.state_dict()
self.assertIsInstance(state_dict['test_buffer'], UninitializedBuffer)
new_module = LazyModule()
# An error is raised when there is an attempt to replace an existing buffer
# with an uninitialized one
new_module.register_buffer('test_buffer', torch.ones(5, 5))
with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
new_module.load_state_dict(state_dict)
# Uninitialized buffers are overridden when the state dict to be loaded contains a valid one
new_module = LazyModule()
new_module.register_buffer('test_buffer', torch.ones(5, 5))
module.load_state_dict(new_module.state_dict())
self.assertEqual(module.test_buffer, torch.ones((5, 5)))
# Uninitialized buffers are left unchanged
module = LazyModule()
module.register_buffer('test_buffer', UninitializedBuffer())
self.assertTrue(module.has_uninitialized_params())
new_module = LazyModule()
new_module.register_buffer('test_buffer', UninitializedBuffer())
module.load_state_dict(new_module.state_dict())
self.assertTrue(module.has_uninitialized_params())
@suppress_warnings
def test_lazy_module_jit_param(self):
module = LazyModule()
module.register_parameter('test_param', UninitializedParameter())
self.assertTrue(module.has_uninitialized_params())
@@ -14284,13 +14314,29 @@ class TestLazyModules(TestCase):
torch.jit.script(module)
@suppress_warnings
def test_lazy_share_memory(self):
def test_lazy_module_jit_buffer(self):
module = LazyModule()
module.register_buffer('test_buffer', UninitializedBuffer())
self.assertTrue(module.has_uninitialized_params())
with self.assertRaisesRegex(RuntimeError, 'run a forward pass'):
torch.jit.script(module)
@suppress_warnings
def test_lazy_share_memory_param(self):
module = LazyModule()
module.register_parameter('test_param', UninitializedParameter())
self.assertTrue(module.has_uninitialized_params())
with self.assertRaisesRegex(RuntimeError, 'share memory on an uninitialized'):
module.share_memory()
@suppress_warnings
def test_lazy_share_memory_buffer(self):
module = LazyModule()
module.register_buffer('test_buffer', UninitializedBuffer())
self.assertTrue(module.has_uninitialized_params())
with self.assertRaisesRegex(RuntimeError, 'share memory on an uninitialized'):
module.share_memory()
@suppress_warnings
def test_linear(self):
module = nn.LazyLinear(10)
@@ -14464,6 +14510,120 @@ class TestLazyModules(TestCase):
lambda: nn.LazyConvTranspose3d(32, 2),
(16, 32, 2, 2, 2))
def _check_lazy_batchnorm(self, cls, lazy_cls, input_shape):
for affine in [False, True]:
for track_running_stats in [False, True]:
lazy_module = lazy_cls(affine=affine, track_running_stats=track_running_stats)
if affine:
self.assertIsInstance(lazy_module.weight, UninitializedParameter)
self.assertIsInstance(lazy_module.bias, UninitializedParameter)
if track_running_stats:
self.assertIsInstance(lazy_module.running_mean, UninitializedBuffer)
self.assertIsInstance(lazy_module.running_var, UninitializedBuffer)
input = torch.ones(*input_shape)
y = lazy_module(input)
self.assertIsInstance(lazy_module, cls)
self.assertNotIsInstance(lazy_module, lazy_cls)
num_features = input_shape[1]
module = cls(num_features, affine=affine, track_running_stats=track_running_stats)
expected = module(input)
if module.weight is not None:
self.assertEqual(lazy_module.weight.shape, module.weight.shape)
self.assertEqual(lazy_module.weight, module.weight)
if module.bias is not None:
self.assertEqual(lazy_module.bias.shape, module.bias.shape)
self.assertEqual(lazy_module.bias, module.bias)
if module.running_mean is not None:
self.assertEqual(lazy_module.running_mean.shape, module.running_mean.shape)
self.assertEqual(lazy_module.running_mean, module.running_mean)
if module.running_var is not None:
self.assertEqual(lazy_module.running_var.shape, module.running_var.shape)
self.assertEqual(lazy_module.running_var, module.running_var)
if module.num_batches_tracked is not None:
self.assertEqual(lazy_module.num_batches_tracked.shape, module.num_batches_tracked.shape)
self.assertEqual(lazy_module.num_batches_tracked, module.num_batches_tracked)
def _check_lazy_batchnorm_pickle(self, cls, lazy_cls, input_shape):
for affine in [False, True]:
for track_running_stats in [False, True]:
module = lazy_cls(affine=affine, track_running_stats=track_running_stats)
module = pickle.loads(pickle.dumps(module))
self.assertIsInstance(module, lazy_cls)
if affine:
self.assertIsInstance(module.weight, UninitializedParameter)
self.assertIsInstance(module.bias, UninitializedParameter)
if track_running_stats:
self.assertIsInstance(module.running_mean, UninitializedBuffer)
self.assertIsInstance(module.running_var, UninitializedBuffer)
input = torch.ones(*input_shape)
module(input) # fully materialized
module = pickle.loads(pickle.dumps(module))
self.assertNotIsInstance(module, lazy_cls)
self.assertIsInstance(module, cls)
if affine:
self.assertNotIsInstance(module.weight, UninitializedParameter)
self.assertNotIsInstance(module.bias, UninitializedParameter)
if track_running_stats:
self.assertNotIsInstance(module.running_mean, UninitializedBuffer)
self.assertNotIsInstance(module.running_var, UninitializedBuffer)
def _check_lazy_batchnorm_state(self, cls, lazy_cls):
module = cls(10)
lazy_module = lazy_cls(affine=True, track_running_stats=True)
lazy_module.load_state_dict(module.state_dict())
# Parameters have been initialized but the module won't become a full
# BatchNorm one until the first iteration. This is due to
# limitations on the state_dict loading logic
self.assertFalse(lazy_module.has_uninitialized_params())
self.assertEqual(lazy_module.weight.shape, (10,))
self.assertEqual(lazy_module.bias.shape, (10,))
self.assertEqual(lazy_module.running_mean.shape, (10,))
self.assertEqual(lazy_module.running_var.shape, (10,))
module = cls(10)
lazy_module = lazy_cls()
with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
module.load_state_dict(lazy_module.state_dict())
def test_lazy_batchnorm1d(self):
self._check_lazy_batchnorm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
self._check_lazy_batchnorm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))
def test_lazy_batchnorm1d_pickle(self):
self._check_lazy_batchnorm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
self._check_lazy_batchnorm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))
def test_lazy_batchnorm1d_state(self):
self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)
self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)
def test_lazy_batchnorm2d(self):
self._check_lazy_batchnorm(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))
def test_lazy_batchnorm2d_pickle(self):
self._check_lazy_batchnorm_pickle(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))
def test_lazy_batchnorm2d_state(self):
self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)
self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)
def test_lazy_batchnorm3d(self):
self._check_lazy_batchnorm(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8))
def test_lazy_batchnorm3d_pickle(self):
self._check_lazy_batchnorm_pickle(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8))
def test_lazy_batchnorm3d_state(self):
self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)
self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)
@suppress_warnings
def test_materialize_dtype(self):
module = LazyModule()

@@ -597,9 +597,13 @@ def check_module_initialized(mod):
# This is to avoid importing torch.distributed.nn
if not hasattr(mod, 'remote_parameters'):
for name, param in mod._parameters.items():
if isinstance(param, torch.nn.parameter.UninitializedParameter):
if torch.nn.parameter.is_lazy(param):
raise RuntimeError("'{}' has uninitialized parameters {}. Did you forget to run a forward pass?"
.format(torch.typename(type(mod)), name))
for name, buf in mod._buffers.items():
if torch.nn.parameter.is_lazy(buf):
raise RuntimeError("'{}' has uninitialized buffers {}. Did you forget to run a forward pass?"
.format(torch.typename(type(mod)), name))
def infer_methods_to_compile(nn_module):
"""

@@ -1,5 +1,5 @@
from .modules import *
from .parameter import Parameter, UninitializedParameter
from .parameter import Parameter, UninitializedParameter, UninitializedBuffer
from .parallel import DataParallel
from . import init
from . import utils

@@ -15,7 +15,8 @@ from .container import Container, Sequential, ModuleList, ModuleDict, ParameterL
from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \
AdaptiveMaxPool1d, AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d
from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm
from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm, \
LazyBatchNorm1d, LazyBatchNorm2d, LazyBatchNorm3d
from .instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d
from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm
from .dropout import Dropout, Dropout2d, Dropout3d, AlphaDropout, FeatureAlphaDropout
@@ -58,5 +59,6 @@ __all__ = [
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
'LazyBatchNorm1d', 'LazyBatchNorm2d', 'LazyBatchNorm3d',
'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle'
]

@@ -1,8 +1,9 @@
import torch
from torch import Tensor
from ._functions import SyncBatchNorm as sync_batch_norm
from .lazy import LazyModuleMixin
from .module import Module
from torch.nn.parameter import Parameter
from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer
from .. import functional as F
from .. import init
@@ -140,6 +141,36 @@ class _BatchNorm(_NormBase):
self.weight, self.bias, bn_training, exponential_average_factor, self.eps)
class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):
def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
super(_LazyBatchNorm, self).__init__(
0, eps, momentum, affine, track_running_stats)
if self.affine:
self.weight = UninitializedParameter()
self.bias = UninitializedParameter()
if self.track_running_stats:
self.running_mean = UninitializedBuffer()
self.running_var = UninitializedBuffer()
def reset_parameters(self) -> None:
if not self.has_uninitialized_params() and self.num_features != 0:
super().reset_parameters()
def initialize_parameters(self, input) -> None: # type: ignore
if self.has_uninitialized_params():
self.num_features = input.shape[1]
if self.affine:
assert isinstance(self.weight, UninitializedParameter)
assert isinstance(self.bias, UninitializedParameter)
self.weight.materialize((self.num_features,))
self.bias.materialize((self.num_features,))
if self.track_running_stats:
self.running_mean.materialize((self.num_features,))
self.running_var.materialize((self.num_features,))
self.reset_parameters()
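Aside (illustrative sketch, not part of the diff): when affine or track_running_stats is False, the corresponding attributes are left as None by the _NormBase constructor and are never replaced by uninitialized placeholders above, so there is nothing to materialize for them; the module still converts into the regular class after its first forward.

import torch
import torch.nn as nn

m = nn.LazyBatchNorm2d(affine=False, track_running_stats=False)
assert m.weight is None and m.bias is None
assert m.running_mean is None and m.running_var is None
m(torch.randn(2, 4, 3, 3))           # works; normalization uses batch statistics only
assert isinstance(m, nn.BatchNorm2d)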
class BatchNorm1d(_BatchNorm):
r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
inputs with optional additional channel dimension) as described in the paper
@@ -213,6 +244,35 @@ class BatchNorm1d(_BatchNorm):
.format(input.dim()))
class LazyBatchNorm1d(_LazyBatchNorm):
r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization of
the ``num_features`` argument, which is inferred from ``input.size(1)``.
Args:
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Can be set to ``None`` for cumulative moving average
(i.e. simple average). Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics, and initializes statistics
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
When these buffers are ``None``, this module always uses batch statistics
in both training and eval modes. Default: ``True``
"""
cls_to_become = BatchNorm1d # type: ignore[assignment]
def _check_input_dim(self, input):
if input.dim() != 2 and input.dim() != 3:
raise ValueError('expected 2D or 3D input (got {}D input)'
.format(input.dim()))
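A short illustrative sketch (not part of the diff): the 1d variant accepts either (N, C) or (N, C, L) input and infers C from dimension 1 in both cases.

import torch
import torch.nn as nn

m = nn.LazyBatchNorm1d()
m(torch.randn(4, 7, 20))             # (N, C, L); num_features is inferred as 7
assert isinstance(m, nn.BatchNorm1d) and m.num_features == 7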
class BatchNorm2d(_BatchNorm):
r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
with additional channel dimension) as described in the paper
@@ -286,6 +346,35 @@ class BatchNorm2d(_BatchNorm):
.format(input.dim()))
class LazyBatchNorm2d(_LazyBatchNorm):
r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization of
the ``num_features`` argument, which is inferred from ``input.size(1)``.
Args:
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Can be set to ``None`` for cumulative moving average
(i.e. simple average). Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics, and initializes statistics
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
When these buffers are ``None``, this module always uses batch statistics
in both training and eval modes. Default: ``True``
"""
cls_to_become = BatchNorm2d # type: ignore[assignment]
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(input.dim()))
class BatchNorm3d(_BatchNorm):
r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
with additional channel dimension) as described in the paper
@@ -360,6 +449,35 @@ class BatchNorm3d(_BatchNorm):
.format(input.dim()))
class LazyBatchNorm3d(_LazyBatchNorm):
r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization of
the ``num_features`` argument, which is inferred from ``input.size(1)``.
Args:
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Can be set to ``None`` for cumulative moving average
(i.e. simple average). Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
this module does not track such statistics, and initializes statistics
buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
When these buffers are ``None``, this module always uses batch statistics
in both training and eval modes. Default: ``True``
"""
cls_to_become = BatchNorm3d # type: ignore[assignment]
def _check_input_dim(self, input):
if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)'
.format(input.dim()))
class SyncBatchNorm(_BatchNorm):
r"""Applies Batch Normalization over a N-Dimensional input (a mini-batch of [N-2]D inputs
with additional channel dimension) as described in the paper

@@ -3,7 +3,7 @@ from typing_extensions import Protocol
import warnings
import torch
from ..parameter import UninitializedParameter
from ..parameter import is_lazy
class _LazyProtocol(Protocol):
@@ -181,13 +181,14 @@ class LazyModuleMixin:
# which is not clean
for name, param in self._parameters.items():
if param is not None:
if isinstance(param, UninitializedParameter):
destination[prefix + name] = param
else:
destination[prefix + name] = param if keep_vars else param.detach()
if not (is_lazy(param) or keep_vars):
param = param.detach()
destination[prefix + name] = param
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf.detach()
if not (is_lazy(buf) or keep_vars):
buf = buf.detach()
destination[prefix + name] = buf
def _lazy_load_hook(
self: _LazyProtocol, state_dict, prefix, local_metadata, strict,
@@ -201,15 +202,14 @@ class LazyModuleMixin:
See comment in ``torch.nn.Module._register_load_state_dict_pre_hook``
for the details of the hook specification.
"""
local_state = {k: v for k, v in self._parameters.items() if v is not None}
for name, param in local_state.items():
for name, param in itertools.chain(self._parameters.items(), self._buffers.items()):
key = prefix + name
if key in state_dict:
if key in state_dict and param is not None:
input_param = state_dict[key]
if isinstance(param, UninitializedParameter):
if is_lazy(param):
# The current parameter is not initialized but the one being loaded is
# create a new parameter based on the uninitialized one
if not isinstance(input_param, UninitializedParameter):
if not is_lazy(input_param):
with torch.no_grad():
param.materialize(input_param.shape)
@@ -226,8 +226,9 @@ class LazyModuleMixin:
# This is to avoid the JIT tracking this parameter and forcing
# custom modules' __setstate__ to add it
params = self._parameters.values()
for param in itertools.chain(params):
if isinstance(param, (UninitializedParameter)):
buffers = self._buffers.values()
for param in itertools.chain(params, buffers):
if is_lazy(param):
return True
return False
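To illustrate what the two hooks above add for buffers (a sketch mirroring _check_lazy_batchnorm_state in the test diff, not part of the changed files): loading a fully initialized state dict into a lazy module materializes its parameters and buffers from the loaded shapes, while loading uninitialized entries into an already initialized module fails because their shape cannot be read.

import torch
import torch.nn as nn

src = nn.BatchNorm2d(10)
lazy = nn.LazyBatchNorm2d()
lazy.load_state_dict(src.state_dict())    # _lazy_load_hook materializes from the loaded shapes
assert not lazy.has_uninitialized_params()
assert lazy.running_mean.shape == (10,)
# The reverse direction raises RuntimeError ('shape of an uninitialized parameter or buffer'):
# src.load_state_dict(nn.LazyBatchNorm2d().state_dict())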

@@ -1137,7 +1137,7 @@ class Module:
# This is used to avoid copying uninitialized parameters into
# non-lazy modules, since they don't have the hook to do the checks;
# in such a case, it will error when accessing the .shape attribute.
is_param_lazy = isinstance(param, torch.nn.parameter.UninitializedParameter)
is_param_lazy = torch.nn.parameter.is_lazy(param)
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0]

@@ -46,18 +46,7 @@ class Parameter(torch.Tensor):
__torch_function__ = _disabled_torch_function_impl
class UninitializedParameter(Parameter):
r"""A parameter that is not initialized.
Uninitialized Parameters are a special case of :class:`torch.nn.Parameter`
where the shape of the data is still unknown.
Unlike a :class:`torch.nn.Parameter`, uninitialized parameters
hold no data and attempting to access some properties, like their shape,
will throw a runtime error. The only operations that can be performed on an uninitialized
parameter are changing its datatype, moving it to a different device and
converting it to a regular :class:`torch.nn.Parameter`.
"""
class UninitializedTensorMixin:
_allowed_methods = [
torch.Tensor.__hash__,
torch.Tensor.size,
@@ -74,14 +63,11 @@ class UninitializedParameter(Parameter):
torch.Tensor.cpu,
torch.Tensor.to,
torch.Tensor.get_device,
torch._has_compatible_shallow_copy_type]
def __new__(cls, requires_grad=True):
data = torch.Tensor()
return torch.Tensor._make_subclass(cls, data, requires_grad)
torch._has_compatible_shallow_copy_type,
]
def materialize(self, shape, device=None, dtype=None):
r"""Create a Parameter with the same properties of the uninitialized one.
r"""Create a Parameter or Tensor with the same properties of the uninitialized one.
Given a shape, it materializes a parameter in the same device
and with the same `dtype` as the current one or the specified ones in the
arguments.
@@ -98,29 +84,29 @@ class UninitializedParameter(Parameter):
if dtype is None:
dtype = self.data.dtype
self.data = torch.empty(shape, device=device, dtype=dtype)
self.__class__ = Parameter
self.__class__ = self.cls_to_become
@property
def shape(self):
raise RuntimeError(
'Can\'t access the shape of an uninitialized parameter. '
'Can\'t access the shape of an uninitialized parameter or buffer. '
'This error usually happens in `load_state_dict` when trying to load '
'an uninitialized parameter into an initialized one. '
'Call `forward` to initialize the parameters before accessing their attributes.')
def share_memory_(self):
raise RuntimeError(
'Can\'t share memory on an uninitialized parameter. '
'Can\'t share memory on an uninitialized parameter or buffer. '
'Call `forward` to initialize the parameters before calling '
'`module.share_memory()`.')
def __repr__(self):
return 'Uninitialized parameter'
return f'<{self.__class__.__name__}>'
def __reduce_ex__(self, proto):
# See Note [Don't serialize hooks]
return (
UninitializedParameter,
self.__class__,
(self.requires_grad,)
)
@@ -135,6 +121,50 @@ class UninitializedParameter(Parameter):
raise ValueError(
'Attempted to use an uninitialized parameter in {}. '
'This error happens when you are using a `LazyModule` or '
'explicitly manipulating `torch.nn.parameter.UninitializedParameter` '
'explicitly manipulating `torch.nn.parameter.{}` '
'objects. When using LazyModules, call `forward` with a dummy batch '
'to initialize the parameters before calling torch functions'.format(func))
'to initialize the parameters before calling torch functions'.format(func, cls.__name__))
def is_lazy(param):
return isinstance(param, UninitializedTensorMixin)
class UninitializedParameter(UninitializedTensorMixin, Parameter):
r"""A parameter that is not initialized.
Uninitialized Parameters are a special case of :class:`torch.nn.Parameter`
where the shape of the data is still unknown.
Unlike a :class:`torch.nn.Parameter`, uninitialized parameters
hold no data and attempting to access some properties, like their shape,
will throw a runtime error. The only operations that can be performed on an uninitialized
parameter are changing its datatype, moving it to a different device and
converting it to a regular :class:`torch.nn.Parameter`.
"""
cls_to_become = Parameter
def __new__(cls, requires_grad=True):
data = torch.Tensor()
return torch.Tensor._make_subclass(cls, data, requires_grad)
class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
r"""A buffer that is not initialized.
Uninitialized Buffers are a special case of :class:`torch.Tensor`
where the shape of the data is still unknown.
Unlike a :class:`torch.Tensor`, uninitialized buffers
hold no data and attempting to access some properties, like their shape,
will throw a runtime error. The only operations that can be performed on an uninitialized
buffer are changing its datatype, moving it to a different device and
converting it to a regular :class:`torch.Tensor`.
"""
cls_to_become = torch.Tensor
def __new__(cls, requires_grad=False):
data = torch.Tensor()
return torch.Tensor._make_subclass(cls, data, requires_grad)
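A small sketch of the new primitives on their own (illustrative, not part of the diff): is_lazy recognizes both uninitialized flavours, and materialize allocates storage and then swaps the instance's class to cls_to_become.

import torch
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter, is_lazy

buf = UninitializedBuffer()
assert is_lazy(buf)
buf.materialize((3,))                # allocates torch.empty((3,)); the buffer becomes a plain torch.Tensor
assert not is_lazy(buf) and buf.shape == (3,)

param = UninitializedParameter()
param.materialize((2, 4))            # becomes a regular torch.nn.Parameter
assert isinstance(param, torch.nn.Parameter) and param.shape == (2, 4)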

@@ -8,8 +8,16 @@ class Parameter(Tensor):
...
def is_lazy(param: Tensor): ...
class UninitializedParameter(Tensor):
def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ...
def materialize(self, shape: Tuple[int, ...], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ...
...
class UninitializedBuffer(Tensor):
def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ...
def materialize(self, shape: Tuple[int, ...], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ...
...