pytorch/torch/nn/utils/spectral_norm.py

"""
Spectral Normalization from https://arxiv.org/abs/1802.05957
"""
import torch
from torch.nn.functional import normalize
from torch.nn.parameter import Parameter


class SpectralNorm(object):

    def __init__(self, name='weight', n_power_iterations=1, eps=1e-12):
        self.name = name
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def compute_weight(self, module):
        weight = module._parameters[self.name + '_org']
        u = module._buffers[self.name + '_u']
        height = weight.size(0)
        weight_mat = weight.view(height, -1)
        # The spectral norm of `weight` equals `u^T W v`, where `u` and `v`
        # are the first left and right singular vectors.
        # This power iteration produces approximations of `u` and `v`. It
        # runs under `no_grad` so the iterates stay detached from the
        # autograd graph; otherwise storing `u` in a buffer would keep each
        # forward pass's graph alive into the next one.
        with torch.no_grad():
            for _ in range(self.n_power_iterations):
                v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
                u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)

        sigma = torch.dot(u, torch.matmul(weight_mat, v))
        # Divide out of place so that the 1 / sigma rescaling is recorded in
        # the autograd graph; an in-place `weight.data /= sigma` would both
        # mutate the stored parameter and hide the rescaling from autograd.
        weight = weight / sigma
        return weight, u
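
    # A quick sanity-check sketch of the iteration above (illustrative, not
    # part of the class): with enough iterations, `sigma` converges to the
    # largest singular value reported by a full SVD.
    #
    #   W = torch.randn(5, 3)
    #   u = normalize(torch.randn(5), dim=0)
    #   for _ in range(100):
    #       v = normalize(torch.matmul(W.t(), u), dim=0)
    #       u = normalize(torch.matmul(W, v), dim=0)
    #   sigma = torch.dot(u, torch.matmul(W, v))
    #   # sigma is close to torch.svd(W)[1][0], the top singular value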

    def remove(self, module):
        weight = module._parameters[self.name + '_org']
        # `self.name` is a plain attribute set by the hook (see `apply`), so
        # remove it with delattr rather than from `_parameters`.
        delattr(module, self.name)
        del module._buffers[self.name + '_u']
        del module._parameters[self.name + '_org']
        module.register_parameter(self.name, weight)

    def __call__(self, module, inputs):
        # Recompute the normalized weight from `weight_org` before every
        # forward pass of the wrapped module.
        weight, u = self.compute_weight(module)
        setattr(module, self.name, weight)
        setattr(module, self.name + '_u', u)

    @staticmethod
    def apply(module, name, n_power_iterations, eps):
        fn = SpectralNorm(name, n_power_iterations, eps)
        weight = module._parameters[name]
        height = weight.size(0)

        # Start the power iteration from a random unit vector.
        u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
        # Move the parameter to `name + '_org'` and re-expose `name` as a
        # plain attribute, so the hook can assign the normalized, non-leaf
        # tensor to it on each forward pass.
        delattr(module, fn.name)
        module.register_parameter(fn.name + "_org", weight)
        module.register_buffer(fn.name + "_u", u)
        setattr(module, fn.name, weight.data)

        module.register_forward_pre_hook(fn)
        return fn
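

# A sketch of the state `SpectralNorm.apply(module, 'weight', ...)` leaves on
# an nn.Linear(20, 40) (illustrative; the entry names follow the code above):
#
#   module.weight_org   # trainable Parameter, shape (40, 20)
#   module.weight_u     # buffer holding the running estimate of the first
#                       # left singular vector, shape (40,)
#   module.weight       # plain tensor, rewritten by the hook as
#                       # weight_org / sigma before every forward pass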


def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12):
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
        \sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with the spectral norm :math:`\sigma` of the weight matrix, calculated
    using the power iteration method. If the dimension of the weight tensor is
    greater than 2, it is reshaped to 2D in the power iteration method to get
    the spectral norm. This is implemented via a hook that calculates the
    spectral norm and rescales the weight before every
    :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms

    Returns:
        The original module with the spectral norm hook

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()
        torch.Size([40])
    """
    SpectralNorm.apply(module, name, n_power_iterations, eps)
    return module
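

# A usage sketch (illustrative, not part of this module): wrapping each layer
# of a GAN discriminator so every weight matrix is rescaled to spectral norm 1
# on each forward pass.
#
#   import torch.nn as nn
#   from torch.nn.utils import spectral_norm
#
#   discriminator = nn.Sequential(
#       spectral_norm(nn.Linear(784, 256)),
#       nn.LeakyReLU(0.2),
#       spectral_norm(nn.Linear(256, 1)),
#   )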


def remove_spectral_norm(module, name='weight'):
    r"""Removes the spectral normalization reparameterization from a module.

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter

    Example::

        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("spectral_norm of '{}' not found in {}".format(
        name, module))
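

# A round-trip sketch (illustrative): after removal, `weight` is an ordinary
# Parameter again and the `weight_org` / `weight_u` entries are gone.
#
#   m = spectral_norm(torch.nn.Linear(20, 40))
#   m(torch.randn(8, 20))          # the forward pre-hook normalizes weight
#   remove_spectral_norm(m)
#   assert isinstance(m.weight, torch.nn.Parameter)
#   assert not hasattr(m, 'weight_org')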