import torch
from ..utils import parametrize
from ..modules import Module
from .. import functional as F

from typing import Optional


class _SpectralNorm(Module):
    def __init__(
        self,
        weight: torch.Tensor,
        n_power_iterations: int = 1,
        dim: int = 0,
        eps: float = 1e-12
    ) -> None:
        super().__init__()
        ndim = weight.ndim
        if dim >= ndim or dim < -ndim:
            raise IndexError("Dimension out of range (expected to be in range of "
                             f"[-{ndim}, {ndim - 1}] but got {dim})")

        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(n_power_iterations))
        self.dim = dim if dim >= 0 else dim + ndim
        self.eps = eps
        if ndim > 1:
            # For ndim == 1 we do not need to approximate anything (see _SpectralNorm.forward)
            self.n_power_iterations = n_power_iterations
            weight_mat = self._reshape_weight_to_matrix(weight)
            h, w = weight_mat.size()

            u = weight_mat.new_empty(h).normal_(0, 1)
            v = weight_mat.new_empty(w).normal_(0, 1)
            self.register_buffer('_u', F.normalize(u, dim=0, eps=self.eps))
            self.register_buffer('_v', F.normalize(v, dim=0, eps=self.eps))
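            # Registering `_u` and `_v` as buffers (rather than plain attributes)
            # keeps them on the right device under `.to()`/`.cuda()` and includes
            # them in the module's `state_dict`.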

            # Start with u, v initialized to some reasonable values by performing a number
            # of iterations of the power method
            self._power_method(weight_mat, 15)

    def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
        # Precondition
        assert weight.ndim > 1

        if self.dim != 0:
            # permute dim to front
            weight = weight.permute(self.dim, *(d for d in range(weight.dim()) if d != self.dim))

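        # Shape example: a Conv2d weight of shape (out_channels, in_channels, kH, kW)
        # with the default dim=0 becomes an (out_channels, in_channels * kH * kW)
        # matrix below.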
        return weight.flatten(1)

    @torch.autograd.no_grad()
    def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> None:
        # See original note at torch/nn/utils/spectral_norm.py
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        #     updated in power iteration **in-place**. This is very important
        #     because in `DataParallel` forward, the vectors (being buffers) are
        #     broadcast from the parallelized module to each module replica,
        #     which is a new module object created on the fly. And each replica
        #     runs its own spectral norm power iteration. So simply assigning
        #     the updated vectors to the module this function runs on will cause
        #     the update to be lost forever. And the next time the parallelized
        #     module is replicated, the same randomly initialized vectors are
        #     broadcast and used!
        #
        #     Therefore, to make the change propagate back, we rely on two
        #     important behaviors (also enforced via tests):
        #       1. `DataParallel` doesn't clone storage if the broadcast tensor
        #          is already on the correct device; and it makes sure that the
        #          parallelized module is already on `device[0]`.
        #       2. If the out tensor in the `out=` kwarg has the correct shape, it
        #          will just fill in the values.
        #     Therefore, since the same power iteration is performed on all
        #     devices, simply updating the tensors in-place will make sure that
        #     the module replica on `device[0]` will update the _u vector on the
        #     parallelized module (by shared storage).
        #
        #     However, after we update `u` and `v` in-place, we need to **clone**
        #     them before using them to normalize the weight. This is to support
        #     backpropagating through two forward passes, e.g., the common pattern in
        #     GAN training: loss = D(real) - D(fake). Otherwise, the autograd engine
        #     will complain that variables needed to do backward for the first forward
        #     (i.e., the `u` and `v` vectors) are changed in the second forward.

        # Precondition
        assert weight_mat.ndim > 1
        for _ in range(n_power_iterations):
            # The spectral norm of the weight equals `u^T W v`, where `u` and `v`
            # are the first left and right singular vectors.
            # This power iteration produces approximations of `u` and `v`.
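            # Each step below computes u <- normalize(W v) and then v <- normalize(W^T u);
            # sigma itself is estimated later, in forward(), as u^T W v.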
            self._u = F.normalize(torch.mv(weight_mat, self._v),     # type: ignore[has-type]
                                  dim=0, eps=self.eps, out=self._u)  # type: ignore[has-type]
            self._v = F.normalize(torch.mv(weight_mat.t(), self._u),
                                  dim=0, eps=self.eps, out=self._v)  # type: ignore[has-type]
        # See above on why we need to clone
        self._u = self._u.clone(memory_format=torch.contiguous_format)
        self._v = self._v.clone(memory_format=torch.contiguous_format)

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        if weight.ndim == 1:
            # Faster and more exact path, no need to approximate anything
            return F.normalize(weight, dim=0, eps=self.eps)
        else:
            weight_mat = self._reshape_weight_to_matrix(weight)
            if self.training:
                self._power_method(weight_mat, self.n_power_iterations)
            # The proper way of computing this should be through F.bilinear, but
            # it seems to have some efficiency issues:
            # https://github.com/pytorch/pytorch/issues/58093
            sigma = torch.dot(self._u, torch.mv(weight_mat, self._v))
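            # sigma approximates the largest singular value of weight_mat, so the
            # division below rescales the weight to (roughly) unit spectral norm.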
            return weight / sigma

    def right_inverse(self, value: torch.Tensor) -> torch.Tensor:
        # we may want to assert here that the passed value already
        # satisfies constraints
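        # `right_inverse` is called by the parametrize machinery when a value is
        # assigned to the parametrized tensor; returning it unchanged stores the
        # assigned value as-is (no projection onto the constraint set).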
        return value


def spectral_norm(module: Module,
                  name: str = 'weight',
                  n_power_iterations: int = 1,
                  eps: float = 1e-12,
                  dim: Optional[int] = None) -> Module:
r"""Applies spectral normalization to a parameter in the given module.
|
|
|
|
.. math::
|
|
\mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
|
|
\sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
|
|
|
|
When applied on a vector, it simplifies to
|
|
|
|
.. math::
|
|
\mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}
|
|
|
|
Spectral normalization stabilizes the training of discriminators (critics)
|
|
in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant
|
|
of the model. :math:`\sigma` is approximated performing one iteration of the
|
|
`power method`_ every time the weight is accessed. If the dimension of the
|
|
weight tensor is greater than 2, it is reshaped to 2D in power iteration
|
|
method to get spectral norm.
|
|
|
|
|
|
See `Spectral Normalization for Generative Adversarial Networks`_ .
|
|
|
|
.. _`power method`: https://en.wikipedia.org/wiki/Power_iteration
|
|
.. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
|
|
|
|
.. note::
|
|
This function is implemented using the new parametrization functionality
|
|
in :func:`torch.nn.utils.parametrize.register_parametrization`. It is a
|
|
reimplementation of :func:`torch.nn.utils.spectral_norm`.
|
|
|
|
.. note::
|
|
When this constraint is registered, the singular vectors associated to the largest
|
|
singular value are estimated rather than sampled at random. These are then updated
|
|
performing :attr:`n_power_iterations` of the `power method`_ whenever the tensor
|
|
is accessed with the module on `training` mode.
|
|
|
|
.. note::
|
|
If the `_SpectralNorm` module, i.e., `module.parametrization.weight[idx]`,
|
|
is in training mode on removal, it will perform another power iteration.
|
|
If you'd like to avoid this iteration, set the module to eval mode
|
|
before its removal.
|
|
|
|
    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, in which case it is ``1``

    Returns:
        The original module with a new parametrization registered to the specified
        weight

    Example::

        >>> snm = spectral_norm(nn.Linear(20, 40))
        >>> snm
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _SpectralNorm()
            )
          )
        )
        >>> torch.linalg.matrix_norm(snm.weight, 2)
        tensor(1.0000, grad_fn=<CopyBackwards>)
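
        A minimal additional sketch (assuming an ``nn.ConvTranspose2d``; the value
        shown follows from the ``dim`` handling below rather than a recorded run)::

            >>> snm = spectral_norm(nn.ConvTranspose2d(16, 8, 3))
            >>> snm.parametrizations.weight[0].dim
            1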
"""
|
|
if not hasattr(module, name):
|
|
raise ValueError(
|
|
"Module '{}' has no attribute with name '{}'".format(module, name)
|
|
)
|
|
# getattr should get the correct parametrized weight if there
|
|
# is already an parametrization registered
|
|
weight = getattr(module, name)
|
|
|
|
if dim is None:
|
|
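        # ConvTranspose{1,2,3}d keep their output channels in dimension 1 of the
        # weight (shape (in_channels, out_channels // groups, ...)), hence dim=1.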
        if isinstance(module, (torch.nn.ConvTranspose1d,
                               torch.nn.ConvTranspose2d,
                               torch.nn.ConvTranspose3d)):
            dim = 1
        else:
            dim = 0
    parametrize.register_parametrization(module, name, _SpectralNorm(weight, n_power_iterations, dim, eps))
    return module
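

# Removal sketch (follows the eval-mode note in the docstring above; uses
# torch.nn.utils.parametrize.remove_parametrizations):
#
#     snm = spectral_norm(torch.nn.Linear(20, 40))
#     snm.eval()  # avoid one extra power iteration on removal
#     parametrize.remove_parametrizations(snm, 'weight')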