pytorch/torch/nn/utils/weight_norm.py
Senthil Kumar N 3f62531191 Fix: docstring errors in torch.nn.utils - parametrizations.py/prune.py/weight_norm.py (#113021)
Fixes #112631. The previous PR #112943 contained an accidental merge, which is resolved by this PR.

- torch/nn/utils/parametrizations.py
**Before - 6**
```
torch\nn\utils\parametrizations.py:1 at module level:
        D100: Missing docstring in public module
torch\nn\utils\parametrizations.py:23 in private function `_make_orthogonal`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\parametrizations.py:23 in private function `_make_orthogonal`:
        D210: No whitespaces allowed surrounding docstring text
torch\nn\utils\parametrizations.py:178 in public function `orthogonal`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
torch\nn\utils\parametrizations.py:309 in public function `weight_norm`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
torch\nn\utils\parametrizations.py:483 in public function `spectral_norm`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
6
```
**After - 1**
```
torch\nn\utils\parametrizations.py:1 at module level:
        D100: Missing docstring in public module
1
```
- torch/nn/utils/prune.py
**Before - 100**
```
torch\nn\utils\prune.py:1 at module level:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch\nn\utils\prune.py:1 at module level:
        D400: First line should end with a period (not 's')
torch\nn\utils\prune.py:13 in public class `BasePruningMethod`:
        D204: 1 blank line required after class docstring (found 0)
torch\nn\utils\prune.py:21 in public method `__call__`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:21 in public method `__call__`:
        D400: First line should end with a period (not ')')
torch\nn\utils\prune.py:34 in public method `compute_mask`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:34 in public method `compute_mask`:
        D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
torch\nn\utils\prune.py:53 in public method `apply_mask`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:53 in public method `apply_mask`:
        D400: First line should end with a period (not 'g')
torch\nn\utils\prune.py:74 in public method `apply`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:74 in public method `apply`:
        D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:74 in public method `apply`:
        D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:200 in public method `prune`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:200 in public method `prune`:
        D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:200 in public method `prune`:
        D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
torch\nn\utils\prune.py:229 in public method `remove`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:229 in public method `remove`:
        D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:229 in public method `remove`:
        D401: First line should be in imperative mood (perhaps 'Remove', not 'Removes')
torch\nn\utils\prune.py:256 in public class `PruningContainer`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:264 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:277 in public method `add_pruning_method`:
        D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:297 in public method `__len__`:
        D105: Missing docstring in magic method
torch\nn\utils\prune.py:300 in public method `__iter__`:
        D105: Missing docstring in magic method
torch\nn\utils\prune.py:303 in public method `__getitem__`:
        D105: Missing docstring in magic method
torch\nn\utils\prune.py:307 in public method `compute_mask`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:307 in public method `compute_mask`:
        D400: First line should end with a period (not 's')
torch\nn\utils\prune.py:307 in public method `compute_mask`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
torch\nn\utils\prune.py:335 in private nested function `_combine_masks`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:335 in private nested function `_combine_masks`:
        D400: First line should end with a period (not ':')
torch\nn\utils\prune.py:404 in public class `Identity`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:404 in public class `Identity`:
        D400: First line should end with a period (not 'e')
torch\nn\utils\prune.py:410 in public method `compute_mask`:
        D102: Missing docstring in public method
torch\nn\utils\prune.py:416 in public method `apply`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:416 in public method `apply`:
        D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:416 in public method `apply`:
        D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:442 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:447 in public method `compute_mask`:
        D102: Missing docstring in public method
torch\nn\utils\prune.py:469 in public method `apply`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:469 in public method `apply`:
        D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:469 in public method `apply`:
        D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:486 in public class `L1Unstructured`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:486 in public class `L1Unstructured`:
        D400: First line should end with a period (not 's')
torch\nn\utils\prune.py:498 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:503 in public method `compute_mask`:
        D102: Missing docstring in public method
torch\nn\utils\prune.py:527 in public method `apply`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:527 in public method `apply`:
        D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:527 in public method `apply`:
        D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:564 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:571 in public method `compute_mask`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:571 in public method `compute_mask`:
        D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
torch\nn\utils\prune.py:634 in public method `apply`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:634 in public method `apply`:
        D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:634 in public method `apply`:
        D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:653 in public class `LnStructured`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:653 in public class `LnStructured`:
        D400: First line should end with a period (not 'r')
torch\nn\utils\prune.py:669 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:677 in public method `compute_mask`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:677 in public method `compute_mask`:
        D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
torch\nn\utils\prune.py:747 in public method `apply`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:747 in public method `apply`:
        D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:747 in public method `apply`:
        D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:779 in public class `CustomFromMask`:
        D101: Missing docstring in public class
torch\nn\utils\prune.py:783 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:786 in public method `compute_mask`:
        D102: Missing docstring in public method
torch\nn\utils\prune.py:793 in public method `apply`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:793 in public method `apply`:
        D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:793 in public method `apply`:
        D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:806 in public function `identity`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:806 in public function `identity`:
        D400: First line should end with a period (not 'e')
torch\nn\utils\prune.py:806 in public function `identity`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
torch\nn\utils\prune.py:839 in public function `random_unstructured`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:839 in public function `random_unstructured`:
        D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:874 in public function `l1_unstructured`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:874 in public function `l1_unstructured`:
        D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:916 in public function `random_structured`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:916 in public function `random_structured`:
        D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:955 in public function `ln_structured`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:955 in public function `ln_structured`:
        D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:1000 in public function `global_unstructured`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1000 in public function `global_unstructured`:
        D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:1120 in public function `custom_from_mask`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1120 in public function `custom_from_mask`:
        D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:1154 in public function `remove`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1154 in public function `remove`:
        D400: First line should end with a period (not 'e')
torch\nn\utils\prune.py:1154 in public function `remove`:
        D401: First line should be in imperative mood (perhaps 'Remove', not 'Removes')
torch\nn\utils\prune.py:1184 in public function `is_pruned`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1184 in public function `is_pruned`:
        D400: First line should end with a period (not 'r')
torch\nn\utils\prune.py:1211 in private function `_validate_pruning_amount_init`:
        D401: First line should be in imperative mood (perhaps 'Validate', not 'Validation')
torch\nn\utils\prune.py:1243 in private function `_validate_pruning_amount`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1243 in private function `_validate_pruning_amount`:
        D400: First line should end with a period (not 'e')
torch\nn\utils\prune.py:1243 in private function `_validate_pruning_amount`:
        D401: First line should be in imperative mood (perhaps 'Validate', not 'Validation')
torch\nn\utils\prune.py:1265 in private function `_validate_structured_pruning`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1265 in private function `_validate_structured_pruning`:
        D400: First line should end with a period (not '-')
torch\nn\utils\prune.py:1265 in private function `_validate_structured_pruning`:
        D401: First line should be in imperative mood (perhaps 'Validate', not 'Validation')
torch\nn\utils\prune.py:1284 in private function `_compute_nparams_toprune`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1284 in private function `_compute_nparams_toprune`:
        D400: First line should end with a period (not 'a')
torch\nn\utils\prune.py:1308 in private function `_validate_pruning_dim`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1308 in private function `_validate_pruning_dim`:
        D400: First line should end with a period (not ':')
torch\nn\utils\prune.py:1318 in private function `_compute_norm`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1318 in private function `_compute_norm`:
        D400: First line should end with a period (not 'n')
100
```
**After - 14**
```
torch\nn\utils\prune.py:266 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:299 in public method `__len__`:
        D105: Missing docstring in magic method
torch\nn\utils\prune.py:302 in public method `__iter__`:
        D105: Missing docstring in magic method
torch\nn\utils\prune.py:305 in public method `__getitem__`:
        D105: Missing docstring in magic method
torch\nn\utils\prune.py:411 in public method `compute_mask`:
        D102: Missing docstring in public method
torch\nn\utils\prune.py:445 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:450 in public method `compute_mask`:
        D102: Missing docstring in public method
torch\nn\utils\prune.py:502 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:507 in public method `compute_mask`:
        D102: Missing docstring in public method
torch\nn\utils\prune.py:570 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:677 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:790 in public class `CustomFromMask`:
        D101: Missing docstring in public class
torch\nn\utils\prune.py:794 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:797 in public method `compute_mask`:
        D102: Missing docstring in public method
14
```
- torch/nn/utils/weight_norm.py
**Before - 10**
```
torch\nn\utils\weight_norm.py:1 at module level:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch\nn\utils\weight_norm.py:1 at module level:
        D400: First line should end with a period (not '8')
torch\nn\utils\weight_norm.py:12 in public class `WeightNorm`:
        D101: Missing docstring in public class
torch\nn\utils\weight_norm.py:16 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\weight_norm.py:23 in public method `compute_weight`:
        D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:29 in public method `apply`:
        D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:59 in public method `remove`:
        D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:66 in public method `__call__`:
        D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:73 in public function `weight_norm`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
torch\nn\utils\weight_norm.py:137 in public function `remove_weight_norm`:
        D401: First line should be in imperative mood (perhaps 'Remove', not 'Removes')
10
```
**After - 6**
```
torch\nn\utils\weight_norm.py:10 in public class `WeightNorm`:
        D101: Missing docstring in public class
torch\nn\utils\weight_norm.py:14 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\weight_norm.py:21 in public method `compute_weight`:
        D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:27 in public method `apply`:
        D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:57 in public method `remove`:
        D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:64 in public method `__call__`:
        D102: Missing docstring in public method
6
```
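
Most of the codes reported above (D205, D400, D401) concern the first lines of a docstring. The following is a minimal sketch of a docstring that should satisfy all three, assuming pydocstyle's usual meaning for each code. `scale` is a hypothetical function used only for illustration; it is not part of `torch.nn.utils`.

```python
def scale(values, factor):
    """Multiply every value by a factor.

    The summary line above is in the imperative mood (D401), ends with a
    period (D400), and is separated from this description by exactly one
    blank line (D205).
    """
    return [v * factor for v in values]


# Inspect the first two docstring lines: the summary and the blank separator.
summary, separator = scale.__doc__.splitlines()[:2]
```

Rewording a summary from "Applies ..." to "Apply ..." is the same kind of change this commit makes to `weight_norm` and `remove_weight_norm`.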

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113021
Approved by: https://github.com/lezcano
2023-11-06 17:24:32 +00:00

152 lines
5.6 KiB
Python

r"""Weight Normalization from https://arxiv.org/abs/1602.07868."""
from torch.nn.parameter import Parameter, UninitializedParameter
from torch import _weight_norm, norm_except_dim
from typing import Any, TypeVar
import warnings
from ..modules import Module

__all__ = ['WeightNorm', 'weight_norm', 'remove_weight_norm']

class WeightNorm:
    name: str
    dim: int

    def __init__(self, name: str, dim: int) -> None:
        if dim is None:
            dim = -1
        self.name = name
        self.dim = dim

    # TODO Make return type more specific
    def compute_weight(self, module: Module) -> Any:
        g = getattr(module, self.name + '_g')
        v = getattr(module, self.name + '_v')
        return _weight_norm(v, g, self.dim)

    @staticmethod
    def apply(module, name: str, dim: int) -> 'WeightNorm':
        warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")

        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError(f"Cannot register two weight_norm hooks on the same parameter {name}")

        if dim is None:
            dim = -1

        fn = WeightNorm(name, dim)

        weight = getattr(module, name)
        if isinstance(weight, UninitializedParameter):
            raise ValueError(
                'The module passed to `WeightNorm` can\'t have uninitialized parameters. '
                'Make sure to run the dummy forward before applying weight normalization')
        # remove w from parameter list
        del module._parameters[name]

        # add g and v as new parameters and express w as g/||v|| * v
        module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim).data))
        module.register_parameter(name + '_v', Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module: Module) -> None:
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_g']
        del module._parameters[self.name + '_v']
        setattr(module, self.name, Parameter(weight.data))

    def __call__(self, module: Module, inputs: Any) -> None:
        setattr(module, self.name, self.compute_weight(module))


T_module = TypeVar('T_module', bound=Module)

def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_module:
    r"""Apply weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude
    (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``).
    Weight normalization is implemented via a hook that recomputes the weight
    tensor from the magnitude and direction before every :meth:`~Module.forward`
    call.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    .. warning::

        This function is deprecated. Use :func:`torch.nn.utils.parametrizations.weight_norm`
        which uses the modern parametrization API. The new ``weight_norm`` is compatible
        with ``state_dict`` generated from old ``weight_norm``.

        Migration guide:

        * The magnitude (``weight_g``) and direction (``weight_v``) are now expressed
          as ``parametrizations.weight.original0`` and ``parametrizations.weight.original1``
          respectively. If this is bothering you, please comment on
          https://github.com/pytorch/pytorch/issues/102999

        * To remove the weight normalization reparametrization, use
          :func:`torch.nn.utils.parametrize.remove_parametrizations`.

        * The weight is no longer recomputed once at module forward; instead, it will
          be recomputed on every access. To restore the old behavior, use
          :func:`torch.nn.utils.parametrize.cached` before invoking the module
          in question.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_g.size()
        torch.Size([40, 1])
        >>> m.weight_v.size()
        torch.Size([40, 20])

    """
    WeightNorm.apply(module, name, dim)
    return module


def remove_weight_norm(module: T_module, name: str = 'weight') -> T_module:
    r"""Remove the weight normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = weight_norm(nn.Linear(20, 40))
        >>> remove_weight_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, WeightNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError(f"weight_norm of '{name}' not found in {module}")
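
The reparameterization documented in `weight_norm`, w = g * v / ||v|| with one norm per output row when `dim=0`, can be checked by hand. The following is a minimal pure-Python sketch and not the torch implementation: `split_weight` and `compute_weight` are hypothetical helpers operating on nested lists, standing in loosely for what `norm_except_dim` and `_weight_norm` do on tensors.

```python
import math


def split_weight(w):
    """Split a 2-D weight (list of rows) into per-row magnitude g and direction v."""
    g = [math.sqrt(sum(x * x for x in row)) for row in w]
    v = [list(row) for row in w]  # like the module above, v starts as a copy of w
    return g, v


def compute_weight(g, v):
    """Recompose w = g * v / ||v|| row by row (assumes no all-zero row)."""
    w = []
    for g_i, row in zip(g, v):
        norm = math.sqrt(sum(x * x for x in row))
        w.append([g_i * x / norm for x in row])
    return w


w = [[3.0, 4.0], [1.0, 0.0]]
g, v = split_weight(w)
w_rebuilt = compute_weight(g, v)

# Rescaling the direction does not change the recomposed weight,
# because only g carries the magnitude.
v_scaled = [[10.0 * x for x in row] for row in v]
w_from_scaled = compute_weight(g, v_scaled)
```

In this sketch `g` holds one magnitude per row, which mirrors the `torch.Size([40, 1])` shape of `weight_g` in the docstring example, and both `w_rebuilt` and `w_from_scaled` reproduce the original `w` up to floating-point rounding.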