Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Fixes #112631. The previous PR, #112943, contained an accidental merge, so the issue is resolved through this PR instead.

- torch/nn/utils/parametrizations.py

**Before - 6**

```
torch\nn\utils\parametrizations.py:1 at module level: D100: Missing docstring in public module
torch\nn\utils\parametrizations.py:23 in private function `_make_orthogonal`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\parametrizations.py:23 in private function `_make_orthogonal`: D210: No whitespaces allowed surrounding docstring text
torch\nn\utils\parametrizations.py:178 in public function `orthogonal`: D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
torch\nn\utils\parametrizations.py:309 in public function `weight_norm`: D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
torch\nn\utils\parametrizations.py:483 in public function `spectral_norm`: D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
6
```

**After - 1**

```
torch\nn\utils\parametrizations.py:1 at module level: D100: Missing docstring in public module
1
```

- torch/nn/utils/prune.py

**Before - 100**

```
torch\nn\utils\prune.py:1 at module level: D200: One-line docstring should fit on one line with quotes (found 3)
torch\nn\utils\prune.py:1 at module level: D400: First line should end with a period (not 's')
torch\nn\utils\prune.py:13 in public class `BasePruningMethod`: D204: 1 blank line required after class docstring (found 0)
torch\nn\utils\prune.py:21 in public method `__call__`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:21 in public method `__call__`: D400: First line should end with a period (not ')')
torch\nn\utils\prune.py:34 in public method `compute_mask`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:34 in public method `compute_mask`: D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
torch\nn\utils\prune.py:53 in public method `apply_mask`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:53 in public method `apply_mask`: D400: First line should end with a period (not 'g')
torch\nn\utils\prune.py:74 in public method `apply`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:74 in public method `apply`: D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:74 in public method `apply`: D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:200 in public method `prune`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:200 in public method `prune`: D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:200 in public method `prune`: D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
torch\nn\utils\prune.py:229 in public method `remove`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:229 in public method `remove`: D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:229 in public method `remove`: D401: First line should be in imperative mood (perhaps 'Remove', not 'Removes')
torch\nn\utils\prune.py:256 in public class `PruningContainer`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:264 in public method `__init__`: D107: Missing docstring in __init__
torch\nn\utils\prune.py:277 in public method `add_pruning_method`: D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:297 in public method `__len__`: D105: Missing docstring in magic method
torch\nn\utils\prune.py:300 in public method `__iter__`: D105: Missing docstring in magic method
torch\nn\utils\prune.py:303 in public method `__getitem__`: D105: Missing docstring in magic method
torch\nn\utils\prune.py:307 in public method `compute_mask`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:307 in public method `compute_mask`: D400: First line should end with a period (not 's')
torch\nn\utils\prune.py:307 in public method `compute_mask`: D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
torch\nn\utils\prune.py:335 in private nested function `_combine_masks`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:335 in private nested function `_combine_masks`: D400: First line should end with a period (not ':')
torch\nn\utils\prune.py:404 in public class `Identity`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:404 in public class `Identity`: D400: First line should end with a period (not 'e')
torch\nn\utils\prune.py:410 in public method `compute_mask`: D102: Missing docstring in public method
torch\nn\utils\prune.py:416 in public method `apply`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:416 in public method `apply`: D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:416 in public method `apply`: D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:442 in public method `__init__`: D107: Missing docstring in __init__
torch\nn\utils\prune.py:447 in public method `compute_mask`: D102: Missing docstring in public method
torch\nn\utils\prune.py:469 in public method `apply`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:469 in public method `apply`: D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:469 in public method `apply`: D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:486 in public class `L1Unstructured`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:486 in public class `L1Unstructured`: D400: First line should end with a period (not 's')
torch\nn\utils\prune.py:498 in public method `__init__`: D107: Missing docstring in __init__
torch\nn\utils\prune.py:503 in public method `compute_mask`: D102: Missing docstring in public method
torch\nn\utils\prune.py:527 in public method `apply`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:527 in public method `apply`: D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:527 in public method `apply`: D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:564 in public method `__init__`: D107: Missing docstring in __init__
torch\nn\utils\prune.py:571 in public method `compute_mask`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:571 in public method `compute_mask`: D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
torch\nn\utils\prune.py:634 in public method `apply`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:634 in public method `apply`: D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:634 in public method `apply`: D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:653 in public class `LnStructured`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:653 in public class `LnStructured`: D400: First line should end with a period (not 'r')
torch\nn\utils\prune.py:669 in public method `__init__`: D107: Missing docstring in __init__
torch\nn\utils\prune.py:677 in public method `compute_mask`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:677 in public method `compute_mask`: D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
torch\nn\utils\prune.py:747 in public method `apply`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:747 in public method `apply`: D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:747 in public method `apply`: D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:779 in public class `CustomFromMask`: D101: Missing docstring in public class
torch\nn\utils\prune.py:783 in public method `__init__`: D107: Missing docstring in __init__
torch\nn\utils\prune.py:786 in public method `compute_mask`: D102: Missing docstring in public method
torch\nn\utils\prune.py:793 in public method `apply`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:793 in public method `apply`: D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:793 in public method `apply`: D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:806 in public function `identity`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:806 in public function `identity`: D400: First line should end with a period (not 'e')
torch\nn\utils\prune.py:806 in public function `identity`: D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
torch\nn\utils\prune.py:839 in public function `random_unstructured`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:839 in public function `random_unstructured`: D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:874 in public function `l1_unstructured`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:874 in public function `l1_unstructured`: D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:916 in public function `random_structured`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:916 in public function `random_structured`: D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:955 in public function `ln_structured`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:955 in public function `ln_structured`: D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:1000 in public function `global_unstructured`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1000 in public function `global_unstructured`: D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:1120 in public function `custom_from_mask`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1120 in public function `custom_from_mask`: D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:1154 in public function `remove`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1154 in public function `remove`: D400: First line should end with a period (not 'e')
torch\nn\utils\prune.py:1154 in public function `remove`: D401: First line should be in imperative mood (perhaps 'Remove', not 'Removes')
torch\nn\utils\prune.py:1184 in public function `is_pruned`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1184 in public function `is_pruned`: D400: First line should end with a period (not 'r')
torch\nn\utils\prune.py:1211 in private function `_validate_pruning_amount_init`: D401: First line should be in imperative mood (perhaps 'Validate', not 'Validation')
torch\nn\utils\prune.py:1243 in private function `_validate_pruning_amount`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1243 in private function `_validate_pruning_amount`: D400: First line should end with a period (not 'e')
torch\nn\utils\prune.py:1243 in private function `_validate_pruning_amount`: D401: First line should be in imperative mood (perhaps 'Validate', not 'Validation')
torch\nn\utils\prune.py:1265 in private function `_validate_structured_pruning`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1265 in private function `_validate_structured_pruning`: D400: First line should end with a period (not '-')
torch\nn\utils\prune.py:1265 in private function `_validate_structured_pruning`: D401: First line should be in imperative mood (perhaps 'Validate', not 'Validation')
torch\nn\utils\prune.py:1284 in private function `_compute_nparams_toprune`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1284 in private function `_compute_nparams_toprune`: D400: First line should end with a period (not 'a')
torch\nn\utils\prune.py:1308 in private function `_validate_pruning_dim`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1308 in private function `_validate_pruning_dim`: D400: First line should end with a period (not ':')
torch\nn\utils\prune.py:1318 in private function `_compute_norm`: D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1318 in private function `_compute_norm`: D400: First line should end with a period (not 'n')
100
```

**After - 14**

```
torch\nn\utils\prune.py:266 in public method `__init__`: D107: Missing docstring in __init__
torch\nn\utils\prune.py:299 in public method `__len__`: D105: Missing docstring in magic method
torch\nn\utils\prune.py:302 in public method `__iter__`: D105: Missing docstring in magic method
torch\nn\utils\prune.py:305 in public method `__getitem__`: D105: Missing docstring in magic method
torch\nn\utils\prune.py:411 in public method `compute_mask`: D102: Missing docstring in public method
torch\nn\utils\prune.py:445 in public method `__init__`: D107: Missing docstring in __init__
torch\nn\utils\prune.py:450 in public method `compute_mask`: D102: Missing docstring in public method
torch\nn\utils\prune.py:502 in public method `__init__`: D107: Missing docstring in __init__
torch\nn\utils\prune.py:507 in public method `compute_mask`: D102: Missing docstring in public method
torch\nn\utils\prune.py:570 in public method `__init__`: D107: Missing docstring in __init__
torch\nn\utils\prune.py:677 in public method `__init__`: D107: Missing docstring in __init__
torch\nn\utils\prune.py:790 in public class `CustomFromMask`: D101: Missing docstring in public class
torch\nn\utils\prune.py:794 in public method `__init__`: D107: Missing docstring in __init__
torch\nn\utils\prune.py:797 in public method `compute_mask`: D102: Missing docstring in public method
14
```

- torch/nn/utils/weight_norm.py

**Before - 10**

```
torch\nn\utils\weight_norm.py:1 at module level: D200: One-line docstring should fit on one line with quotes (found 3)
torch\nn\utils\weight_norm.py:1 at module level: D400: First line should end with a period (not '8')
torch\nn\utils\weight_norm.py:12 in public class `WeightNorm`: D101: Missing docstring in public class
torch\nn\utils\weight_norm.py:16 in public method `__init__`: D107: Missing docstring in __init__
torch\nn\utils\weight_norm.py:23 in public method `compute_weight`: D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:29 in public method `apply`: D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:59 in public method `remove`: D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:66 in public method `__call__`: D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:73 in public function `weight_norm`: D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
torch\nn\utils\weight_norm.py:137 in public function `remove_weight_norm`: D401: First line should be in imperative mood (perhaps 'Remove', not 'Removes')
10
```

**After - 6**

```
torch\nn\utils\weight_norm.py:10 in public class `WeightNorm`: D101: Missing docstring in public class
torch\nn\utils\weight_norm.py:14 in public method `__init__`: D107: Missing docstring in __init__
torch\nn\utils\weight_norm.py:21 in public method `compute_weight`: D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:27 in public method `apply`: D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:57 in public method `remove`: D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:64 in public method `__call__`: D102: Missing docstring in public method
6
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113021
Approved by: https://github.com/lezcano
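Most of the warnings fixed above fall under three pydocstyle rules: D205 (one blank line between the summary line and the description), D400 (the summary line ends with a period) and D401 (the summary line is in the imperative mood, e.g. 'Apply' rather than 'Applies'). As a rough illustration only, the hypothetical helper `summary_issues` below approximates D205 and D400 in plain Python. It is not pydocstyle's implementation, and D401 is left out because judging the imperative mood needs a verb list.

```python
import inspect


def summary_issues(docstring: str) -> list[str]:
    """Approximate pydocstyle's D400 and D205 checks (hypothetical helper).

    Return the rule codes the docstring appears to violate.
    """
    lines = inspect.cleandoc(docstring).splitlines()
    issues = []
    if not lines:
        return issues
    # D400 (simplified): the summary line should end with a period.
    if not lines[0].rstrip().endswith("."):
        issues.append("D400")
    # D205 (simplified): the line right after the summary must be blank.
    if len(lines) > 1 and lines[1].strip() != "":
        issues.append("D205")
    return issues


def before(module):
    """Applies weight normalization
    to a parameter in the given module"""


def after(module):
    """Apply weight normalization to a parameter in the given module.

    The summary is followed by a blank line and ends with a period.
    """


print(summary_issues(before.__doc__))  # ['D400', 'D205']
print(summary_issues(after.__doc__))   # []
```

Note that `before` would also trip D401 ('Applies'), which this sketch does not detect; pydocstyle itself remains the authoritative check.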
152 lines
5.6 KiB
Python
r"""Weight Normalization from https://arxiv.org/abs/1602.07868."""
from torch.nn.parameter import Parameter, UninitializedParameter
from torch import _weight_norm, norm_except_dim
from typing import Any, TypeVar
import warnings
from ..modules import Module

__all__ = ['WeightNorm', 'weight_norm', 'remove_weight_norm']

class WeightNorm:
    name: str
    dim: int

    def __init__(self, name: str, dim: int) -> None:
        if dim is None:
            dim = -1
        self.name = name
        self.dim = dim

    # TODO Make return type more specific
    def compute_weight(self, module: Module) -> Any:
        g = getattr(module, self.name + '_g')
        v = getattr(module, self.name + '_v')
        return _weight_norm(v, g, self.dim)

    @staticmethod
    def apply(module, name: str, dim: int) -> 'WeightNorm':
        warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")

        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError(f"Cannot register two weight_norm hooks on the same parameter {name}")

        if dim is None:
            dim = -1

        fn = WeightNorm(name, dim)

        weight = getattr(module, name)
        if isinstance(weight, UninitializedParameter):
            raise ValueError(
                'The module passed to `WeightNorm` can\'t have uninitialized parameters. '
                'Make sure to run the dummy forward before applying weight normalization')
        # remove w from parameter list
        del module._parameters[name]

        # add g and v as new parameters and express w as g/||v|| * v
        module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim).data))
        module.register_parameter(name + '_v', Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module: Module) -> None:
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_g']
        del module._parameters[self.name + '_v']
        setattr(module, self.name, Parameter(weight.data))

    def __call__(self, module: Module, inputs: Any) -> None:
        setattr(module, self.name, self.compute_weight(module))


T_module = TypeVar('T_module', bound=Module)

def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_module:
    r"""Apply weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude
    (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``).
    Weight normalization is implemented via a hook that recomputes the weight
    tensor from the magnitude and direction before every :meth:`~Module.forward`
    call.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    .. warning::

        This function is deprecated. Use :func:`torch.nn.utils.parametrizations.weight_norm`
        which uses the modern parametrization API. The new ``weight_norm`` is compatible
        with ``state_dict`` generated from old ``weight_norm``.

        Migration guide:

        * The magnitude (``weight_g``) and direction (``weight_v``) are now expressed
          as ``parametrizations.weight.original0`` and ``parametrizations.weight.original1``
          respectively. If this is bothering you, please comment on
          https://github.com/pytorch/pytorch/issues/102999

        * To remove the weight normalization reparametrization, use
          :func:`torch.nn.utils.parametrize.remove_parametrizations`.

        * The weight is no longer recomputed once at module forward; instead, it will
          be recomputed on every access. To restore the old behavior, use
          :func:`torch.nn.utils.parametrize.cached` before invoking the module
          in question.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_g.size()
        torch.Size([40, 1])
        >>> m.weight_v.size()
        torch.Size([40, 20])

    """
    WeightNorm.apply(module, name, dim)
    return module


def remove_weight_norm(module: T_module, name: str = 'weight') -> T_module:
    r"""Remove the weight normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = weight_norm(nn.Linear(20, 40))
        >>> remove_weight_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, WeightNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError(f"weight_norm of '{name}' not found in {module}")
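The reparametrization described in the `weight_norm` docstring, w = g v / ||v||, can be sketched in plain Python for a 2-D weight stored as a list of rows. This is an illustration under that assumption, not the `_weight_norm` kernel, and the helpers `row_norms`, `total_norm` and `compose_weight` are hypothetical. With `dim=0` each output row gets its own norm and magnitude, similar to what `norm_except_dim(weight, 2, 0)` computes in `WeightNorm.apply` (kept here as a flat list rather than a `[rows, 1]` tensor); with `dim=None` a single norm covers the whole tensor. Because `apply` initializes g to the norm of the original weight and v to the weight itself, the recomputed weight starts out equal to the original.

```python
import math


def row_norms(v):
    """Hypothetical helper: L2 norm of each row (the dim=0 case for a 2-D weight)."""
    return [math.sqrt(sum(x * x for x in row)) for row in v]


def total_norm(v):
    """Hypothetical helper: one L2 norm over every entry (the dim=None case)."""
    return math.sqrt(sum(x * x for row in v for x in row))


def compose_weight(g, v, dim):
    """Hypothetical helper: recompute w = g * v / ||v|| for a list-of-rows weight.

    dim=0:    g is a flat list with one magnitude per row (output channel).
    dim=None: g is a single float magnitude for the whole tensor.
    """
    if dim == 0:
        norms = row_norms(v)
        return [[gi * x / n for x in row] for gi, n, row in zip(g, norms, v)]
    if dim is None:
        n = total_norm(v)
        return [[g * x / n for x in row] for row in v]
    raise ValueError("this sketch only handles dim=0 and dim=None")


weight = [[3.0, 4.0], [0.0, 2.0]]

# Initialization as in WeightNorm.apply: g = per-row norm of weight, v = weight.
g = row_norms(weight)
print(g)                                           # [5.0, 2.0]
print(compose_weight(g, weight, dim=0))            # [[3.0, 4.0], [0.0, 2.0]]

# Doubling the first magnitude doubles the first row; its direction is unchanged.
print(compose_weight([10.0, 2.0], weight, dim=0))  # [[6.0, 8.0], [0.0, 2.0]]

# dim=None: one magnitude for the whole tensor, so the result has norm g.
w = compose_weight(2.0, weight, dim=None)
print(total_norm(w))  # close to 2.0, up to floating-point rounding
```

Changing one entry of g rescales only the matching row while v keeps its direction, which is the decoupling of magnitude and direction the docstring describes.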