Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Use `typing_extensions.deprecated` for deprecation annotations where possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` calls that are missing an explicit category.
Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
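As a minimal illustration of the two patterns (the function names `old_helper`, `new_helper`, `legacy_path`, and `new_path` are hypothetical, not from this PR):

```python
import warnings

from typing_extensions import deprecated


# Preferred: annotate the deprecated callable itself so the warning is
# declared once and emitted with an explicit category.
@deprecated(
    "`old_helper` is deprecated, use `new_helper` instead.",
    category=FutureWarning,
)
def old_helper() -> None:
    ...


# Fallback: where annotation is not possible, keep the runtime warning but
# pass FutureWarning explicitly instead of the default UserWarning.
def legacy_path() -> None:
    warnings.warn(
        "`legacy_path` is deprecated, use `new_path` instead.",
        FutureWarning,
    )
```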
Resolves #126888
This PR is split from PR #126898.
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
155 lines
5.6 KiB
Python
r"""Weight Normalization from https://arxiv.org/abs/1602.07868."""
|
|
from torch.nn.parameter import Parameter, UninitializedParameter
|
|
from torch import _weight_norm, norm_except_dim
|
|
from typing import Any, TypeVar
|
|
from typing_extensions import deprecated
|
|
from ..modules import Module
|
|
|
|
__all__ = ['WeightNorm', 'weight_norm', 'remove_weight_norm']
|
|
|
|
class WeightNorm:
|
|
name: str
|
|
dim: int
|
|
|
|
def __init__(self, name: str, dim: int) -> None:
|
|
if dim is None:
|
|
dim = -1
|
|
self.name = name
|
|
self.dim = dim
|
|
|
|
# TODO Make return type more specific
|
|
def compute_weight(self, module: Module) -> Any:
|
|
g = getattr(module, self.name + '_g')
|
|
v = getattr(module, self.name + '_v')
|
|
return _weight_norm(v, g, self.dim)
|
|
|
|
@staticmethod
|
|
@deprecated(
|
|
"`torch.nn.utils.weight_norm` is deprecated "
|
|
"in favor of `torch.nn.utils.parametrizations.weight_norm`.",
|
|
category=FutureWarning,
|
|
)
|
|
def apply(module, name: str, dim: int) -> 'WeightNorm':
|
|
for hook in module._forward_pre_hooks.values():
|
|
if isinstance(hook, WeightNorm) and hook.name == name:
|
|
raise RuntimeError(f"Cannot register two weight_norm hooks on the same parameter {name}")
|
|
|
|
if dim is None:
|
|
dim = -1
|
|
|
|
fn = WeightNorm(name, dim)
|
|
|
|
weight = getattr(module, name)
|
|
if isinstance(weight, UninitializedParameter):
|
|
raise ValueError(
|
|
'The module passed to `WeightNorm` can\'t have uninitialized parameters. '
|
|
'Make sure to run the dummy forward before applying weight normalization')
|
|
# remove w from parameter list
|
|
del module._parameters[name]
|
|
|
|
# add g and v as new parameters and express w as g/||v|| * v
|
|
module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim).data))
|
|
module.register_parameter(name + '_v', Parameter(weight.data))
|
|
setattr(module, name, fn.compute_weight(module))
|
|
|
|
# recompute weight before every forward()
|
|
module.register_forward_pre_hook(fn)
|
|
|
|
return fn
|
|
|
|
def remove(self, module: Module) -> None:
|
|
weight = self.compute_weight(module)
|
|
delattr(module, self.name)
|
|
del module._parameters[self.name + '_g']
|
|
del module._parameters[self.name + '_v']
|
|
setattr(module, self.name, Parameter(weight.data))
|
|
|
|
def __call__(self, module: Module, inputs: Any) -> None:
|
|
setattr(module, self.name, self.compute_weight(module))
|
|
|
|
|
|
T_module = TypeVar('T_module', bound=Module)
|
|
|
|
def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_module:
|
|
r"""Apply weight normalization to a parameter in the given module.
|
|
|
|
.. math::
|
|
\mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}
|
|
|
|
Weight normalization is a reparameterization that decouples the magnitude
|
|
of a weight tensor from its direction. This replaces the parameter specified
|
|
by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude
|
|
(e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``).
|
|
Weight normalization is implemented via a hook that recomputes the weight
|
|
tensor from the magnitude and direction before every :meth:`~Module.forward`
|
|
call.
|
|
|
|
By default, with ``dim=0``, the norm is computed independently per output
|
|
channel/plane. To compute a norm over the entire weight tensor, use
|
|
``dim=None``.
|
|
|
|
See https://arxiv.org/abs/1602.07868
|
|
|
|
.. warning::
|
|
|
|
This function is deprecated. Use :func:`torch.nn.utils.parametrizations.weight_norm`
|
|
which uses the modern parametrization API. The new ``weight_norm`` is compatible
|
|
with ``state_dict`` generated from old ``weight_norm``.
|
|
|
|
Migration guide:
|
|
|
|
* The magnitude (``weight_g``) and direction (``weight_v``) are now expressed
|
|
as ``parametrizations.weight.original0`` and ``parametrizations.weight.original1``
|
|
respectively. If this is bothering you, please comment on
|
|
https://github.com/pytorch/pytorch/issues/102999
|
|
|
|
* To remove the weight normalization reparametrization, use
|
|
:func:`torch.nn.utils.parametrize.remove_parametrizations`.
|
|
|
|
* The weight is no longer recomputed once at module forward; instead, it will
|
|
be recomputed on every access. To restore the old behavior, use
|
|
:func:`torch.nn.utils.parametrize.cached` before invoking the module
|
|
in question.
|
|
|
|
Args:
|
|
module (Module): containing module
|
|
name (str, optional): name of weight parameter
|
|
dim (int, optional): dimension over which to compute the norm
|
|
|
|
Returns:
|
|
The original module with the weight norm hook
|
|
|
|
Example::
|
|
|
|
>>> m = weight_norm(nn.Linear(20, 40), name='weight')
|
|
>>> m
|
|
Linear(in_features=20, out_features=40, bias=True)
|
|
>>> m.weight_g.size()
|
|
torch.Size([40, 1])
|
|
>>> m.weight_v.size()
|
|
torch.Size([40, 20])
|
|
|
|
"""
|
|
WeightNorm.apply(module, name, dim)
|
|
return module
|
|
|
|
|
|
def remove_weight_norm(module: T_module, name: str = 'weight') -> T_module:
|
|
r"""Remove the weight normalization reparameterization from a module.
|
|
|
|
Args:
|
|
module (Module): containing module
|
|
name (str, optional): name of weight parameter
|
|
|
|
Example:
|
|
>>> m = weight_norm(nn.Linear(20, 40))
|
|
>>> remove_weight_norm(m)
|
|
"""
|
|
for k, hook in module._forward_pre_hooks.items():
|
|
if isinstance(hook, WeightNorm) and hook.name == name:
|
|
hook.remove(module)
|
|
del module._forward_pre_hooks[k]
|
|
return module
|
|
|
|
raise ValueError(f"weight_norm of '{name}' not found in {module}")
|
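For reference, a small usage sketch of the migration path described in the docstring above, assuming the parametrization-based API it points to (`torch.nn.utils.parametrizations.weight_norm` and `torch.nn.utils.parametrize`); this is an illustrative sketch, not part of this file:

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import weight_norm

# New-style weight normalization via the parametrization API.
m = weight_norm(nn.Linear(20, 40), name="weight")

# Magnitude and direction now live under module.parametrizations
# (formerly exposed as weight_g and weight_v).
print(m.parametrizations.weight.original0.shape)
print(m.parametrizations.weight.original1.shape)

# The weight is recomputed on every access; cache it for repeated use.
with parametrize.cached():
    out = m(torch.randn(2, 20))

# Remove the reparametrization, keeping the current value of the weight.
parametrize.remove_parametrizations(m, "weight")
```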