r"""
|
|
Weight Normalization from https://arxiv.org/abs/1602.07868
|
|
"""
|
|
from torch.nn.parameter import Parameter
|
|
from torch import _weight_norm, norm_except_dim
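
# `_weight_norm` and `norm_except_dim` are internal torch helpers. As used below,
# `norm_except_dim(v, 2, dim)` takes the 2-norm of `v` over every dimension
# except `dim`, and `_weight_norm(v, g, dim)` is effectively
# `v * (g / norm_except_dim(v, 2, dim))`, i.e. w = g * v / ||v||.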


class WeightNorm(object):
    def __init__(self, name, dim):
        if dim is None:
            dim = -1
        self.name = name
        self.dim = dim

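    # Rebuilds the weight tensor from the two stored parameters:
    # `<name>_g` (magnitude) and `<name>_v` (direction).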
    def compute_weight(self, module):
        g = getattr(module, self.name + '_g')
        v = getattr(module, self.name + '_v')
        return _weight_norm(v, g, self.dim)

    @staticmethod
    def apply(module, name, dim):
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError("Cannot register two weight_norm hooks on "
                                   "the same parameter {}".format(name))

        if dim is None:
            dim = -1

        fn = WeightNorm(name, dim)

        weight = getattr(module, name)

        # remove w from parameter list
        del module._parameters[name]

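        # `norm_except_dim(weight, 2, dim)` keeps dimension `dim` and reduces the
        # 2-norm over all other dimensions, so e.g. a Linear weight of shape
        # (out_features, in_features) with dim=0 yields a `g` of shape
        # (out_features, 1): one magnitude per output unit.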
        # add g and v as new parameters and express w as g/||v|| * v
        module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim).data))
        module.register_parameter(name + '_v', Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn

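    # Undoes the reparameterization: materializes the current w = g * v / ||v||
    # and re-registers it under the original parameter name.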
    def remove(self, module):
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_g']
        del module._parameters[self.name + '_v']
        module.register_parameter(self.name, Parameter(weight.data))

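    # Runs as a forward pre-hook: refreshes the module's weight attribute from
    # the current values of g and v before each forward().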
    def __call__(self, module, inputs):
        setattr(module, self.name, self.compute_weight(module))


def weight_norm(module, name='weight', dim=0):
    r"""Applies weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude
    (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``).
    Weight normalization is implemented via a hook that recomputes the weight
    tensor from the magnitude and direction before every :meth:`~Module.forward`
    call.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_g.size()
        torch.Size([40, 1])
        >>> m.weight_v.size()
        torch.Size([40, 20])

    """
    WeightNorm.apply(module, name, dim)
    return module
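
# Note: with ``dim=None`` (mapped to dim=-1 above) the norm is taken over the
# entire weight tensor, so ``weight_g`` holds a single scalar magnitude rather
# than one value per output channel.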


def remove_weight_norm(module, name='weight'):
    r"""Removes the weight normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = weight_norm(nn.Linear(20, 40))
        >>> remove_weight_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, WeightNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("weight_norm of '{}' not found in {}"
                     .format(name, module))
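

# A minimal usage sketch under the assumptions of the docstring examples above
# (standard torch / torch.nn only); it is illustrative, not part of the module's
# public surface. It applies weight_norm to a Linear layer, checks that the
# recomputed weight matches g * v / ||v|| computed by hand, and then removes the
# reparameterization again.
if __name__ == "__main__":
    import torch
    import torch.nn as nn

    m = weight_norm(nn.Linear(20, 40), name='weight', dim=0)
    print(m.weight_g.size())   # one magnitude per output unit, e.g. [40, 1]
    print(m.weight_v.size())   # direction tensor, same shape as the weight

    # Recompute w = g * v / ||v|| manually (norm over everything except dim=0)
    # and compare with what the hook produced.
    g, v = m.weight_g, m.weight_v
    w_manual = g * v / v.norm(2, dim=1, keepdim=True)
    print(torch.allclose(m.weight, w_manual))

    # Fold the parameters back into a single plain `weight`.
    remove_weight_norm(m)
    print(hasattr(m, 'weight_g'))   # False once the hook is removed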