Add LazyBatchNormXd (#51548)
Summary: This PR implements UninitializedBuffer and LazyBatchNormXd based on https://github.com/pytorch/pytorch/issues/44538. (cc. emcastillo and albanD)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51548
Reviewed By: zhangguanheng66
Differential Revision: D26276903
Pulled By: albanD
fbshipit-source-id: 0ac706974178363f8af075e59b41d5989418922f

This commit is contained in:
  parent 5a962369e2
  commit aa1fd6b45a
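Not part of the commit itself, but as orientation, here is a minimal usage sketch of the module family this PR adds (input shapes and variable names are illustrative): a lazy batch norm is constructed without num_features, infers it from the channel dimension on the first forward pass, and then rewrites its class to the regular BatchNorm variant.

import torch
import torch.nn as nn

# num_features is not passed; weight, bias, running_mean and running_var
# start out as UninitializedParameter / UninitializedBuffer placeholders.
bn = nn.LazyBatchNorm2d()

x = torch.randn(4, 3, 8, 8)   # illustrative batch: N=4, C=3, H=W=8
y = bn(x)                     # first forward infers num_features=3 and materializes the state

print(type(bn))               # the module has become nn.BatchNorm2d (its cls_to_become)
print(bn.weight.shape)        # torch.Size([3])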
@@ -176,6 +176,9 @@ Normalization Layers

     nn.BatchNorm1d
     nn.BatchNorm2d
     nn.BatchNorm3d
+    nn.LazyBatchNorm1d
+    nn.LazyBatchNorm2d
+    nn.LazyBatchNorm3d
     nn.GroupNorm
     nn.SyncBatchNorm
     nn.InstanceNorm1d
test/test_nn.py (166 changed lines)
@@ -30,7 +30,7 @@ from torch.nn.utils import clip_grad_norm_, clip_grad_value_
 import torch.nn.utils.prune as prune
 from torch.nn.utils import parameters_to_vector, vector_to_parameters
 from torch.nn import Parameter
-from torch.nn.parameter import UninitializedParameter
+from torch.nn.parameter import UninitializedParameter, UninitializedBuffer
 from torch.nn.parallel._functions import Broadcast
 from torch.testing import get_all_fp_dtypes
 from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
@@ -14276,7 +14276,37 @@ class TestLazyModules(TestCase):
         self.assertTrue(module.has_uninitialized_params())

     @suppress_warnings
-    def test_lazy_module_jit(self):
+    def test_lazy_module_buffer(self):
+        module = LazyModule()
+        module.register_buffer('test_buffer', UninitializedBuffer())
+        self.assertTrue(module.has_uninitialized_params())
+        state_dict = module.state_dict()
+        self.assertIsInstance(state_dict['test_buffer'], UninitializedBuffer)
+        new_module = LazyModule()
+        # An error is raised when there is an attempt to replace an existing parameter
+        # with an uninitialized one
+        new_module.register_buffer('test_buffer', torch.ones(5, 5))
+        with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
+            new_module.load_state_dict(state_dict)
+        # Uninitialized parameters are overriden when the state dict to be loaded contains a valid one
+        new_module = LazyModule()
+        new_module.register_buffer('test_buffer', torch.ones(5, 5))
+        module.load_state_dict(new_module.state_dict())
+        self.assertEqual(module.test_buffer, torch.ones((5, 5)))
+
+        # Uninitialized parameters are left unchanged
+        module = LazyModule()
+        module.register_buffer('test_buffer', UninitializedBuffer())
+        self.assertTrue(module.has_uninitialized_params())
+
+        new_module = LazyModule()
+        new_module.register_buffer('test_buffer', UninitializedBuffer())
+        module.load_state_dict(new_module.state_dict())
+        module.load_state_dict(new_module.state_dict())
+        self.assertTrue(module.has_uninitialized_params())
+
+    @suppress_warnings
+    def test_lazy_module_jit_param(self):
         module = LazyModule()
         module.register_parameter('test_param', UninitializedParameter())
         self.assertTrue(module.has_uninitialized_params())
@@ -14284,13 +14314,29 @@ class TestLazyModules(TestCase):
             torch.jit.script(module)

     @suppress_warnings
-    def test_lazy_share_memory(self):
+    def test_lazy_module_jit_buffer(self):
+        module = LazyModule()
+        module.register_buffer('test_buffer', UninitializedBuffer())
+        self.assertTrue(module.has_uninitialized_params())
+        with self.assertRaisesRegex(RuntimeError, 'run a forward pass'):
+            torch.jit.script(module)
+
+    @suppress_warnings
+    def test_lazy_share_memory_param(self):
         module = LazyModule()
         module.register_parameter('test_param', UninitializedParameter())
         self.assertTrue(module.has_uninitialized_params())
         with self.assertRaisesRegex(RuntimeError, 'share memory on an uninitialized'):
             module.share_memory()

+    @suppress_warnings
+    def test_lazy_share_memory_buffer(self):
+        module = LazyModule()
+        module.register_buffer('test_buffer', UninitializedBuffer())
+        self.assertTrue(module.has_uninitialized_params())
+        with self.assertRaisesRegex(RuntimeError, 'share memory on an uninitialized'):
+            module.share_memory()
+
     @suppress_warnings
     def test_linear(self):
         module = nn.LazyLinear(10)
@@ -14464,6 +14510,120 @@ class TestLazyModules(TestCase):
                                     lambda: nn.LazyConvTranspose3d(32, 2),
                                     (16, 32, 2, 2, 2))

+    def _check_lazy_batchnorm(self, cls, lazy_cls, input_shape):
+        for affine in [False, True]:
+            for track_running_stats in [False, True]:
+                lazy_module = lazy_cls(affine=affine, track_running_stats=track_running_stats)
+
+                if affine:
+                    self.assertIsInstance(lazy_module.weight, UninitializedParameter)
+                    self.assertIsInstance(lazy_module.bias, UninitializedParameter)
+                if track_running_stats:
+                    self.assertIsInstance(lazy_module.running_mean, UninitializedBuffer)
+                    self.assertIsInstance(lazy_module.running_var, UninitializedBuffer)
+
+                input = torch.ones(*input_shape)
+                y = lazy_module(input)
+                self.assertIsInstance(lazy_module, cls)
+                self.assertNotIsInstance(lazy_module, lazy_cls)
+
+                num_features = input_shape[1]
+                module = cls(num_features, affine=affine, track_running_stats=track_running_stats)
+                expected = module(input)
+
+                if module.weight is not None:
+                    self.assertEqual(lazy_module.weight.shape, module.weight.shape)
+                    self.assertEqual(lazy_module.weight, module.weight)
+                if module.bias is not None:
+                    self.assertEqual(lazy_module.bias.shape, module.bias.shape)
+                    self.assertEqual(lazy_module.bias, module.bias)
+                if module.running_mean is not None:
+                    self.assertEqual(lazy_module.running_mean.shape, module.running_mean.shape)
+                    self.assertEqual(lazy_module.running_mean, module.running_mean)
+                if module.running_var is not None:
+                    self.assertEqual(lazy_module.running_var.shape, module.running_var.shape)
+                    self.assertEqual(lazy_module.running_var, module.running_var)
+                if module.num_batches_tracked is not None:
+                    self.assertEqual(lazy_module.num_batches_tracked.shape, module.num_batches_tracked.shape)
+                    self.assertEqual(lazy_module.num_batches_tracked, module.num_batches_tracked)
+
+    def _check_lazy_batchnorm_pickle(self, cls, lazy_cls, input_shape):
+        for affine in [False, True]:
+            for track_running_stats in [False, True]:
+                module = lazy_cls(affine=affine, track_running_stats=track_running_stats)
+                module = pickle.loads(pickle.dumps(module))
+
+                self.assertIsInstance(module, lazy_cls)
+                if affine:
+                    self.assertIsInstance(module.weight, UninitializedParameter)
+                    self.assertIsInstance(module.bias, UninitializedParameter)
+                if track_running_stats:
+                    self.assertIsInstance(module.running_mean, UninitializedBuffer)
+                    self.assertIsInstance(module.running_var, UninitializedBuffer)
+
+                input = torch.ones(*input_shape)
+                module(input)  # fully materialized
+                module = pickle.loads(pickle.dumps(module))
+
+                self.assertNotIsInstance(module, lazy_cls)
+                self.assertIsInstance(module, cls)
+                if affine:
+                    self.assertNotIsInstance(module.weight, UninitializedParameter)
+                    self.assertNotIsInstance(module.bias, UninitializedParameter)
+                if track_running_stats:
+                    self.assertNotIsInstance(module.running_mean, UninitializedBuffer)
+                    self.assertNotIsInstance(module.running_var, UninitializedBuffer)
+
+    def _check_lazy_batchnorm_state(self, cls, lazy_cls):
+        module = cls(10)
+        lazy_module = lazy_cls(affine=True, track_running_stats=True)
+        lazy_module.load_state_dict(module.state_dict())
+        # Parameters have been initialized but the module won't become a full
+        # Conv one until the first iteration. This is due to
+        # limitations on the state_dict loading logic
+        self.assertFalse(lazy_module.has_uninitialized_params())
+        self.assertEqual(lazy_module.weight.shape, (10,))
+        self.assertEqual(lazy_module.bias.shape, (10,))
+        self.assertEqual(lazy_module.running_mean.shape, (10,))
+        self.assertEqual(lazy_module.running_var.shape, (10,))
+
+        module = cls(10)
+        lazy_module = lazy_cls()
+        with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'):
+            module.load_state_dict(lazy_module.state_dict())
+
+    def test_lazy_batchnorm1d(self):
+        self._check_lazy_batchnorm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
+        self._check_lazy_batchnorm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))
+
+    def test_lazy_batchnorm1d_pickle(self):
+        self._check_lazy_batchnorm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
+        self._check_lazy_batchnorm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))
+
+    def test_lazy_batchnorm1d_state(self):
+        self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)
+        self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)
+
+    def test_lazy_batchnorm2d(self):
+        self._check_lazy_batchnorm(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))
+
+    def test_lazy_batchnorm2d_pickle(self):
+        self._check_lazy_batchnorm_pickle(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))
+
+    def test_lazy_batchnorm2d_state(self):
+        self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)
+        self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)
+
+    def test_lazy_batchnorm3d(self):
+        self._check_lazy_batchnorm(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8))
+
+    def test_lazy_batchnorm3d_pickle(self):
+        self._check_lazy_batchnorm_pickle(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8))
+
+    def test_lazy_batchnorm3d_state(self):
+        self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)
+        self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)
+
     @suppress_warnings
     def test_materialize_dtype(self):
         module = LazyModule()
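A condensed sketch of the behaviour these new tests pin down, using the public nn.LazyBatchNorm1d instead of the test-only LazyModule helper (the feature count 10 is illustrative): uninitialized buffers survive a state_dict round trip as-is, and loading a fully materialized state dict into a lazy module shapes its state without a forward pass.

import torch
import torch.nn as nn
from torch.nn.parameter import UninitializedBuffer

lazy = nn.LazyBatchNorm1d()
sd = lazy.state_dict()
# The uninitialized buffer is stored in the state dict unchanged.
assert isinstance(sd['running_mean'], UninitializedBuffer)

# Loading a materialized state dict materializes the lazy parameters and buffers.
full = nn.BatchNorm1d(10)
lazy.load_state_dict(full.state_dict())
assert not lazy.has_uninitialized_params()
assert lazy.running_mean.shape == (10,)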
@@ -597,9 +597,13 @@ def check_module_initialized(mod):
     # This is to avoid importing torch.distributed.nn
     if not hasattr(mod, 'remote_parameters'):
         for name, param in mod._parameters.items():
-            if isinstance(param, torch.nn.parameter.UninitializedParameter):
+            if torch.nn.parameter.is_lazy(param):
                 raise RuntimeError("'{}' has uninitialized parameters {}. Did you forget to run a forward pass?"
                                    .format(torch.typename(type(mod)), name))
+        for name, buf in mod._buffers.items():
+            if torch.nn.parameter.is_lazy(buf):
+                raise RuntimeError("'{}' has uninitialized buffers {}. Did you forget to run a forward pass?"
+                                   .format(torch.typename(type(mod)), name))

 def infer_methods_to_compile(nn_module):
     """
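The hunk above (presumably TorchScript's check_module_initialized) extends the existing parameter check to buffers, so scripting a module that still holds an UninitializedBuffer fails with the same hint. A hedged sketch of the resulting behaviour, using nn.LazyBatchNorm2d for illustration:

import torch
import torch.nn as nn

bn = nn.LazyBatchNorm2d()
try:
    torch.jit.script(bn)             # still uninitialized: RuntimeError
except RuntimeError as e:
    print(e)                         # "... Did you forget to run a forward pass?"

bn(torch.randn(2, 3, 4, 4))          # materializes parameters and buffers
scripted = torch.jit.script(bn)      # the module is now a regular nn.BatchNorm2d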
@@ -1,5 +1,5 @@
 from .modules import *
-from .parameter import Parameter, UninitializedParameter
+from .parameter import Parameter, UninitializedParameter, UninitializedBuffer
 from .parallel import DataParallel
 from . import init
 from . import utils
@@ -15,7 +15,8 @@ from .container import Container, Sequential, ModuleList, ModuleDict, ParameterL
 from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
     MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \
     AdaptiveMaxPool1d, AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d
-from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm
+from .batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, SyncBatchNorm, \
+    LazyBatchNorm1d, LazyBatchNorm2d, LazyBatchNorm3d
 from .instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d
 from .normalization import LocalResponseNorm, CrossMapLRN2d, LayerNorm, GroupNorm
 from .dropout import Dropout, Dropout2d, Dropout3d, AlphaDropout, FeatureAlphaDropout

@@ -58,5 +59,6 @@ __all__ = [
     'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
     'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
     'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
+    'LazyBatchNorm1d', 'LazyBatchNorm2d', 'LazyBatchNorm3d',
     'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle'
 ]
@@ -1,8 +1,9 @@
 import torch
 from torch import Tensor
 from ._functions import SyncBatchNorm as sync_batch_norm
+from .lazy import LazyModuleMixin
 from .module import Module
-from torch.nn.parameter import Parameter
+from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer
 from .. import functional as F
 from .. import init

@@ -140,6 +141,36 @@ class _BatchNorm(_NormBase):
             self.weight, self.bias, bn_training, exponential_average_factor, self.eps)


+class _LazyBatchNorm(LazyModuleMixin, _BatchNorm):
+
+    def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
+        super(_LazyBatchNorm, self).__init__(
+            0, eps, momentum, affine, track_running_stats)
+        if self.affine:
+            self.weight = UninitializedParameter()
+            self.bias = UninitializedParameter()
+        if self.track_running_stats:
+            self.running_mean = UninitializedBuffer()
+            self.running_var = UninitializedBuffer()
+
+    def reset_parameters(self) -> None:
+        if not self.has_uninitialized_params() and self.num_features != 0:
+            super().reset_parameters()
+
+    def initialize_parameters(self, input) -> None:  # type: ignore
+        if self.has_uninitialized_params():
+            self.num_features = input.shape[1]
+            if self.affine:
+                assert isinstance(self.weight, UninitializedParameter)
+                assert isinstance(self.bias, UninitializedParameter)
+                self.weight.materialize((self.num_features,))
+                self.bias.materialize((self.num_features,))
+            if self.track_running_stats:
+                self.running_mean.materialize((self.num_features,))
+                self.running_var.materialize((self.num_features,))
+            self.reset_parameters()
+
+
 class BatchNorm1d(_BatchNorm):
     r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
     inputs with optional additional channel dimension) as described in the paper

@@ -213,6 +244,35 @@ class BatchNorm1d(_BatchNorm):
                              .format(input.dim()))


+class LazyBatchNorm1d(_LazyBatchNorm):
+    r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization of
+    the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred
+    from the ``input.size(1)``.
+
+    Args:
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``True``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics, and initializes statistics
+            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+            When these buffers are ``None``, this module always uses batch statistics.
+            in both training and eval modes. Default: ``True``
+    """
+
+    cls_to_become = BatchNorm1d  # type: ignore[assignment]
+
+    def _check_input_dim(self, input):
+        if input.dim() != 2 and input.dim() != 3:
+            raise ValueError('expected 2D or 3D input (got {}D input)'
+                             .format(input.dim()))
+
+
 class BatchNorm2d(_BatchNorm):
     r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
     with additional channel dimension) as described in the paper

@@ -286,6 +346,35 @@ class BatchNorm2d(_BatchNorm):
                              .format(input.dim()))


+class LazyBatchNorm2d(_LazyBatchNorm):
+    r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization of
+    the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred
+    from the ``input.size(1)``.
+
+    Args:
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``True``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics, and initializes statistics
+            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+            When these buffers are ``None``, this module always uses batch statistics.
+            in both training and eval modes. Default: ``True``
+    """
+
+    cls_to_become = BatchNorm2d  # type: ignore[assignment]
+
+    def _check_input_dim(self, input):
+        if input.dim() != 4:
+            raise ValueError('expected 4D input (got {}D input)'
+                             .format(input.dim()))
+
+
 class BatchNorm3d(_BatchNorm):
     r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
     with additional channel dimension) as described in the paper

@@ -360,6 +449,35 @@ class BatchNorm3d(_BatchNorm):
                              .format(input.dim()))


+class LazyBatchNorm3d(_LazyBatchNorm):
+    r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization of
+    the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred
+    from the ``input.size(1)``.
+
+    Args:
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
+        affine: a boolean value that when set to ``True``, this module has
+            learnable affine parameters. Default: ``True``
+        track_running_stats: a boolean value that when set to ``True``, this
+            module tracks the running mean and variance, and when set to ``False``,
+            this module does not track such statistics, and initializes statistics
+            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+            When these buffers are ``None``, this module always uses batch statistics.
+            in both training and eval modes. Default: ``True``
+    """
+
+    cls_to_become = BatchNorm3d  # type: ignore[assignment]
+
+    def _check_input_dim(self, input):
+        if input.dim() != 5:
+            raise ValueError('expected 5D input (got {}D input)'
+                             .format(input.dim()))
+
+
 class SyncBatchNorm(_BatchNorm):
     r"""Applies Batch Normalization over a N-Dimensional input (a mini-batch of [N-2]D inputs
     with additional channel dimension) as described in the paper
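A brief sketch of the lazy initialization flow implemented by _LazyBatchNorm above (the 5D input shape is illustrative): before the first forward pass the affine parameters and running statistics are placeholders, afterwards they are materialized from input.shape[1] and the class is swapped to cls_to_become.

import torch
import torch.nn as nn
from torch.nn.parameter import UninitializedParameter, UninitializedBuffer

bn = nn.LazyBatchNorm3d()
assert isinstance(bn.weight, UninitializedParameter)
assert isinstance(bn.running_mean, UninitializedBuffer)

bn(torch.randn(2, 4, 3, 3, 3))          # (N, C, D, H, W) with C=4 here
assert isinstance(bn, nn.BatchNorm3d)   # cls_to_become swapped the class in
assert bn.running_var.shape == (4,)     # materialized from input.shape[1]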
@@ -3,7 +3,7 @@ from typing_extensions import Protocol
 import warnings

 import torch
-from ..parameter import UninitializedParameter
+from ..parameter import is_lazy


 class _LazyProtocol(Protocol):

@@ -181,13 +181,14 @@ class LazyModuleMixin:
         # which is not clean
         for name, param in self._parameters.items():
             if param is not None:
-                if isinstance(param, UninitializedParameter):
-                    destination[prefix + name] = param
-                else:
-                    destination[prefix + name] = param if keep_vars else param.detach()
+                if not (is_lazy(param) or keep_vars):
+                    param = param.detach()
+                destination[prefix + name] = param
         for name, buf in self._buffers.items():
             if buf is not None and name not in self._non_persistent_buffers_set:
-                destination[prefix + name] = buf if keep_vars else buf.detach()
+                if not (is_lazy(buf) or keep_vars):
+                    buf = buf.detach()
+                destination[prefix + name] = buf

     def _lazy_load_hook(
             self: _LazyProtocol, state_dict, prefix, local_metadata, strict,

@@ -201,15 +202,14 @@ class LazyModuleMixin:
         See comment in ``torch.nn.Module._register_load_state_dict_pre_hook``
         for the details of the hook specification.
         """
-        local_state = {k: v for k, v in self._parameters.items() if v is not None}
-        for name, param in local_state.items():
+        for name, param in itertools.chain(self._parameters.items(), self._buffers.items()):
             key = prefix + name
-            if key in state_dict:
+            if key in state_dict and param is not None:
                 input_param = state_dict[key]
-                if isinstance(param, UninitializedParameter):
+                if is_lazy(param):
                     # The current parameter is not initialized but the one being loaded one is
                     # create a new parameter based on the uninitialized one
-                    if not isinstance(input_param, UninitializedParameter):
+                    if not is_lazy(input_param):
                         with torch.no_grad():
                             param.materialize(input_param.shape)

@@ -226,8 +226,9 @@ class LazyModuleMixin:
         # This is to avoid the JIT to track this parameter and force
         # custom modules __setstate__ to add it
         params = self._parameters.values()
-        for param in itertools.chain(params):
-            if isinstance(param, (UninitializedParameter)):
+        buffers = self._buffers.values()
+        for param in itertools.chain(params, buffers):
+            if is_lazy(param):
                 return True
         return False
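A sketch of the load behaviour the reworked hook gives buffers (module choice and feature counts are illustrative, not taken from the diff): an uninitialized buffer is materialized in place when the incoming entry is a real tensor, and both sides being uninitialized is a no-op.

import torch
import torch.nn as nn

src = nn.BatchNorm2d(7)
dst = nn.LazyBatchNorm2d()

# Incoming entries are real tensors, so the lazy parameters and buffers are materialized.
dst.load_state_dict(src.state_dict())
assert dst.running_mean.shape == (7,)

# Both sides uninitialized: nothing to materialize, the module stays lazy.
a, b = nn.LazyBatchNorm2d(), nn.LazyBatchNorm2d()
a.load_state_dict(b.state_dict())
assert a.has_uninitialized_params()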
@@ -1137,7 +1137,7 @@ class Module:
             # This is used to avoid copying uninitialized parameters into
             # non-lazy modules, since they dont have the hook to do the checks
             # in such case, it will error when accessing the .shape attribute.
-            is_param_lazy = isinstance(param, torch.nn.parameter.UninitializedParameter)
+            is_param_lazy = torch.nn.parameter.is_lazy(param)
             # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
             if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
                 input_param = input_param[0]
@@ -46,18 +46,7 @@ class Parameter(torch.Tensor):
     __torch_function__ = _disabled_torch_function_impl


-class UninitializedParameter(Parameter):
-    r"""A parameter that is not initialized.
-
-    Unitialized Parameters are a a special case of :class:`torch.nn.Parameter`
-    where the shape of the data is still unknown.
-
-    Unlikely a :class:`torch.nn.Parameter`, uninitialized parameters
-    hold no data and attempting to access some properties, like their shape,
-    will throw a runtime error. The only operations that can be performed on a uninitialized
-    parameter are changing its datatype, moving it to a different device and
-    converting it to a regular :class:`torch.nn.Parameter`.
-    """
+class UninitializedTensorMixin:
     _allowed_methods = [
         torch.Tensor.__hash__,
         torch.Tensor.size,

@@ -74,16 +63,13 @@ class UninitializedParameter(Parameter):
         torch.Tensor.cpu,
         torch.Tensor.to,
         torch.Tensor.get_device,
-        torch._has_compatible_shallow_copy_type]
-
-    def __new__(cls, requires_grad=True):
-        data = torch.Tensor()
-        return torch.Tensor._make_subclass(cls, data, requires_grad)
+        torch._has_compatible_shallow_copy_type,
+    ]

     def materialize(self, shape, device=None, dtype=None):
-        r"""Create a Parameter with the same properties of the uninitialized one.
+        r"""Create a Parameter or Tensor with the same properties of the uninitialized one.
         Given a shape, it materializes a parameter in the same device
         and with the same `dtype` as the current one or the specified ones in the
         arguments.

         Args:

@@ -98,35 +84,35 @@ class UninitializedParameter(Parameter):
         if dtype is None:
             dtype = self.data.dtype
         self.data = torch.empty(shape, device=device, dtype=dtype)
-        self.__class__ = Parameter
+        self.__class__ = self.cls_to_become

     @property
     def shape(self):
         raise RuntimeError(
-            'Can\'t access the shape of an uninitialized parameter. '
+            'Can\'t access the shape of an uninitialized parameter or buffer. '
             'This error usually happens in `load_state_dict` when trying to load '
             'an uninitialized parameter into an initialized one. '
             'Call `forward` to initialize the parameters before accessing their attributes.')

     def share_memory_(self):
         raise RuntimeError(
-            'Can\'t share memory on an uninitialized parameter. '
+            'Can\'t share memory on an uninitialized parameter or buffer. '
            'Call `forward` to initialize the parameters before calling '
            '`module.share_memory()`.')

     def __repr__(self):
-        return 'Uninitialized parameter'
+        return f'<{self.__class__.__name__}>'

     def __reduce_ex__(self, proto):
         # See Note [Don't serialize hooks]
         return (
-            UninitializedParameter,
+            self.__class__,
             (self.requires_grad,)
         )

     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         # method-wrapper is to detect access to Tensor properties that are
         # wrapped in descriptors
         if func in cls._allowed_methods or func.__class__.__name__ == 'method-wrapper':
             if kwargs is None:

@@ -135,6 +121,50 @@ class UninitializedParameter(Parameter):
             raise ValueError(
                 'Attempted to use an uninitialized parameter in {}. '
                 'This error happens when you are using a `LazyModule` or '
-                'explicitly manipulating `torch.nn.parameter.UninitializedParameter` '
+                'explicitly manipulating `torch.nn.parameter.{}` '
                 'objects. When using LazyModules Call `forward` with a dummy batch '
-                'to initialize the parameters before calling torch functions'.format(func))
+                'to initialize the parameters before calling torch functions'.format(func, cls.__name__))
+
+
+def is_lazy(param):
+    return isinstance(param, UninitializedTensorMixin)
+
+
+class UninitializedParameter(UninitializedTensorMixin, Parameter):
+    r"""A parameter that is not initialized.
+
+    Unitialized Parameters are a a special case of :class:`torch.nn.Parameter`
+    where the shape of the data is still unknown.
+
+    Unlike a :class:`torch.nn.Parameter`, uninitialized parameters
+    hold no data and attempting to access some properties, like their shape,
+    will throw a runtime error. The only operations that can be performed on a uninitialized
+    parameter are changing its datatype, moving it to a different device and
+    converting it to a regular :class:`torch.nn.Parameter`.
+    """
+
+    cls_to_become = Parameter
+
+    def __new__(cls, requires_grad=True):
+        data = torch.Tensor()
+        return torch.Tensor._make_subclass(cls, data, requires_grad)
+
+
+class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
+    r"""A buffer that is not initialized.
+
+    Unitialized Buffer is a a special case of :class:`torch.Tensor`
+    where the shape of the data is still unknown.
+
+    Unlike a :class:`torch.Tensor`, uninitialized parameters
+    hold no data and attempting to access some properties, like their shape,
+    will throw a runtime error. The only operations that can be performed on a uninitialized
+    parameter are changing its datatype, moving it to a different device and
+    converting it to a regular :class:`torch.Tensor`.
+    """
+
+    cls_to_become = torch.Tensor
+
+    def __new__(cls, requires_grad=False):
+        data = torch.Tensor()
+        return torch.Tensor._make_subclass(cls, data, requires_grad)
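A quick sketch of the refactored helpers in torch.nn.parameter (shapes are illustrative): is_lazy detects either placeholder type, and materialize() allocates storage and swaps the instance over to its cls_to_become.

import torch
from torch.nn.parameter import (
    Parameter, UninitializedParameter, UninitializedBuffer, is_lazy)

w = UninitializedParameter()
b = UninitializedBuffer()
assert is_lazy(w) and is_lazy(b)
assert not is_lazy(torch.zeros(3))

w.materialize((4,))   # w is now a regular Parameter of shape (4,)
b.materialize((4,))   # b is now a plain torch.Tensor of shape (4,)
assert isinstance(w, Parameter) and not is_lazy(w)
assert isinstance(b, torch.Tensor) and not is_lazy(b)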
@@ -8,8 +8,16 @@ class Parameter(Tensor):

     ...

+def is_lazy(param: Tensor): ...
+
 class UninitializedParameter(Tensor):
     def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ...

     def materialize(self, shape: Tuple[int, ...], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ...
     ...
+
+class UninitializedBuffer(Tensor):
+    def __init__(self, data: Tensor=..., requires_grad: builtins.bool=...): ...
+
+    def materialize(self, shape: Tuple[int, ...], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): ...
+    ...