mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
* initial commit for spectral norm * fix comment * edit rst * fix doc * remove redundant empty line * fix nit mistakes in doc * replace l2normalize with F.normalize * fix chained `by` * fix docs fix typos add comments related to power iteration and epsilon update link to the paper make some comments specific * fix typo
119 lines
4.3 KiB
Python
119 lines
4.3 KiB
Python
"""
|
|
Spectral Normalization from https://arxiv.org/abs/1802.05957
|
|
"""
|
|
import torch
|
|
from torch.nn.functional import normalize
|
|
from torch.nn.parameter import Parameter
|
|
|
|
|
|
class SpectralNorm(object):
    """Forward pre-hook that rescales a module parameter by its spectral norm.

    The hook is installed with :meth:`apply`.  It stores two extra entries on
    the module:

    * ``<name>_org`` -- registered as a parameter; it is the very same
      ``Parameter`` object as ``<name>``, so both names alias one tensor.
    * ``<name>_u`` -- a buffer holding the current estimate of the first left
      singular vector, refined by power iteration on every forward call.

    Args:
        name (str, optional): name of the weight parameter.
        n_power_iterations (int, optional): number of power iterations run
            on every forward call; must be positive.
        eps (float, optional): epsilon passed to ``normalize`` for numerical
            stability.

    Raises:
        ValueError: if ``n_power_iterations`` is not positive.
    """

    def __init__(self, name='weight', n_power_iterations=1, eps=1e-12):
        # With zero iterations ``v`` would never be assigned in
        # ``compute_weight`` and the first forward call would fail with an
        # unbound-variable error; reject the value up front instead.
        if n_power_iterations <= 0:
            raise ValueError(
                'Expected n_power_iterations to be positive, but got '
                'n_power_iterations={}'.format(n_power_iterations))
        self.name = name
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def compute_weight(self, module):
        """Run power iteration and divide the weight by the estimated norm.

        Args:
            module: module carrying ``<name>_org`` and ``<name>_u``.

        Returns:
            tuple: ``(weight, u)`` -- the rescaled parameter (modified in
            place) and the updated left-singular-vector estimate.
        """
        weight = module._parameters[self.name + '_org']
        u = module._buffers[self.name + '_u']
        height = weight.size(0)
        # Tensors with more than two dimensions are flattened to 2D.
        weight_mat = weight.view(height, -1)
        # The result is only ever applied through ``weight.data`` below, so
        # no gradient is needed here.  Without ``no_grad`` the ``u`` written
        # back into the buffer would carry a ``grad_fn`` and every forward
        # call would extend the graph kept alive by the previous one.
        with torch.no_grad():
            for _ in range(self.n_power_iterations):
                # Spectral norm of weight equals to `u^T W v`, where `u` and
                # `v` are the first left and right singular vectors.
                # This power iteration produces approximations of `u` and `v`.
                v = normalize(torch.matmul(weight_mat.t(), u),
                              dim=0, eps=self.eps)
                u = normalize(torch.matmul(weight_mat, v),
                              dim=0, eps=self.eps)

            sigma = torch.dot(u, torch.matmul(weight_mat, v))
        # In-place on ``.data``: autograd does not record the division, and
        # because ``<name>`` and ``<name>_org`` are the same object the
        # rescaled values are visible through both names.
        weight.data /= sigma
        return weight, u

    def remove(self, module):
        """Drop the extra entries and re-register the weight as ``<name>``.

        The forward pre-hook itself is not unregistered here.
        """
        weight = module._parameters[self.name + '_org']
        del module._parameters[self.name]
        del module._buffers[self.name + '_u']
        del module._parameters[self.name + '_org']
        module.register_parameter(self.name, weight)

    def __call__(self, module, inputs):
        """Forward pre-hook entry point: refresh the weight and ``u``."""
        weight, u = self.compute_weight(module)
        setattr(module, self.name, weight)
        setattr(module, self.name + '_u', u)

    @staticmethod
    def apply(module, name, n_power_iterations, eps):
        """Install a :class:`SpectralNorm` hook on ``module`` and return it.

        Raises:
            ValueError: if ``n_power_iterations`` is not positive (raised
                before the module is modified).
        """
        fn = SpectralNorm(name, n_power_iterations, eps)
        weight = module._parameters[name]
        height = weight.size(0)

        # Random unit vector with one entry per row of the flattened weight.
        u = normalize(weight.new_empty(height).normal_(0, 1),
                      dim=0, eps=fn.eps)
        module.register_parameter(fn.name + "_org", weight)
        module.register_buffer(fn.name + "_u", u)

        module.register_forward_pre_hook(fn)
        return fn
|
|
|
|
|
|
def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12):
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
         \mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
         \sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with spectral norm :math:`\sigma` of the weight matrix calculated using
    power iteration method. If the dimension of the weight tensor is greater
    than 2, it is reshaped to 2D in power iteration method to get spectral
    norm. This is implemented via a hook that calculates spectral norm and
    rescales weight before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms

    Returns:
        The original module with the spectral norm hook

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m.weight_u.size()
        torch.Size([40])

    """
    # ``weight_u`` has one entry per row of the weight (``weight.size(0)``),
    # which is why the example above reports 40 for ``nn.Linear(20, 40)``.
    SpectralNorm.apply(module, name, n_power_iterations, eps)
    return module
|
|
|
|
|
|
def remove_spectral_norm(module, name='weight'):
    r"""Removes the spectral normalization reparameterization from a module.

    Looks through the module's forward pre-hooks for the first
    :class:`SpectralNorm` hook attached to ``name``, lets that hook undo its
    changes to the module and then unregisters it.

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter

    Returns:
        The module that was passed in.

    Raises:
        ValueError: if no spectral norm hook for ``name`` is registered on
            ``module``.

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    pre_hooks = module._forward_pre_hooks
    for key in pre_hooks:
        candidate = pre_hooks[key]
        # Skip anything that is not the spectral norm hook for this weight.
        if not isinstance(candidate, SpectralNorm) or candidate.name != name:
            continue
        candidate.remove(module)
        # Safe to delete while iterating: we return right away.
        del pre_hooks[key]
        return module

    raise ValueError("spectral_norm of '{}' not found in {}".format(
        name, module))
|