Commit Graph

23 Commits

Author SHA1 Message Date
Xuehai Pan
5a80d2df84 [BE] enable UFMT for torch/nn/utils (#128595)
Part of #123062

- #123062
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128595
Approved by: https://github.com/Skylion007
2024-06-13 18:34:57 +00:00
Aaron Orenstein
27f9d3b0a1 Flip default value for mypy disallow_untyped_defs [8/11] (#127845)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127845
Approved by: https://github.com/oulgen
ghstack dependencies: #127842, #127843, #127844
2024-06-08 18:49:56 +00:00
Xuehai Pan
67ef2683d9 [BE] wrap deprecated function/class with typing_extensions.deprecated (#127689)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

Resolves #126888

- #126888

This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
2024-06-02 12:30:43 +00:00
PyTorch MergeBot
033e733021 Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)"
This reverts commit 749a132fb0.

Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions=4.3.0 to 4.9.0 causes internal failure ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456))
2024-05-31 19:47:24 +00:00
Xuehai Pan
749a132fb0 [BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.

Resolves #126888

- #126888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
2024-05-29 12:09:27 +00:00
Senthil Kumar N
3f62531191 Fix: docstring errors in torch.nn.utils - parametrizations.py/prune.py/weight_norm.py (#113021)
Fixes #112631. As the previous PR #112943 has some accidental merge and it resolved through this PR.

- torch/nn/utils/parametrizations.py
**Before - 6**
```
torch\nn\utils\parametrizations.py:1 at module level:
        D100: Missing docstring in public module
torch\nn\utils\parametrizations.py:23 in private function `_make_orthogonal`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\parametrizations.py:23 in private function `_make_orthogonal`:
        D210: No whitespaces allowed surrounding docstring text
torch\nn\utils\parametrizations.py:178 in public function `orthogonal`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
torch\nn\utils\parametrizations.py:309 in public function `weight_norm`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
torch\nn\utils\parametrizations.py:483 in public function `spectral_norm`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
6
```
**After - 1**
```
torch\nn\utils\parametrizations.py:1 at module level:
        D100: Missing docstring in public module
1
```
- torch/nn/utils/prune.py
**Before - 100**
```
torch\nn\utils\prune.py:1 at module level:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch\nn\utils\prune.py:1 at module level:
        D400: First line should end with a period (not 's')
torch\nn\utils\prune.py:13 in public class `BasePruningMethod`:
        D204: 1 blank line required after class docstring (found 0)
torch\nn\utils\prune.py:21 in public method `__call__`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:21 in public method `__call__`:
        D400: First line should end with a period (not ')')
torch\nn\utils\prune.py:34 in public method `compute_mask`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:34 in public method `compute_mask`:
        D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
torch\nn\utils\prune.py:53 in public method `apply_mask`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:53 in public method `apply_mask`:
        D400: First line should end with a period (not 'g')
torch\nn\utils\prune.py:74 in public method `apply`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:74 in public method `apply`:
        D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:74 in public method `apply`:
        D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:200 in public method `prune`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:200 in public method `prune`:
        D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:200 in public method `prune`:
        D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
torch\nn\utils\prune.py:229 in public method `remove`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:229 in public method `remove`:
        D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:229 in public method `remove`:
        D401: First line should be in imperative mood (perhaps 'Remove', not 'Removes')
torch\nn\utils\prune.py:256 in public class `PruningContainer`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:264 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:277 in public method `add_pruning_method`:
        D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:297 in public method `__len__`:
        D105: Missing docstring in magic method
torch\nn\utils\prune.py:300 in public method `__iter__`:
        D105: Missing docstring in magic method
torch\nn\utils\prune.py:303 in public method `__getitem__`:
        D105: Missing docstring in magic method
torch\nn\utils\prune.py:307 in public method `compute_mask`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:307 in public method `compute_mask`:
        D400: First line should end with a period (not 's')
torch\nn\utils\prune.py:307 in public method `compute_mask`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
torch\nn\utils\prune.py:335 in private nested function `_combine_masks`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:335 in private nested function `_combine_masks`:
        D400: First line should end with a period (not ':')
torch\nn\utils\prune.py:404 in public class `Identity`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:404 in public class `Identity`:
        D400: First line should end with a period (not 'e')
torch\nn\utils\prune.py:410 in public method `compute_mask`:
        D102: Missing docstring in public method
torch\nn\utils\prune.py:416 in public method `apply`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:416 in public method `apply`:
        D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:416 in public method `apply`:
        D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:442 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:447 in public method `compute_mask`:
        D102: Missing docstring in public method
torch\nn\utils\prune.py:469 in public method `apply`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:469 in public method `apply`:
        D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:469 in public method `apply`:
        D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:486 in public class `L1Unstructured`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:486 in public class `L1Unstructured`:
        D400: First line should end with a period (not 's')
torch\nn\utils\prune.py:498 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:503 in public method `compute_mask`:
        D102: Missing docstring in public method
torch\nn\utils\prune.py:527 in public method `apply`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:527 in public method `apply`:
        D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:527 in public method `apply`:
        D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:564 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:571 in public method `compute_mask`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:571 in public method `compute_mask`:
        D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
torch\nn\utils\prune.py:634 in public method `apply`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:634 in public method `apply`:
        D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:634 in public method `apply`:
        D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:653 in public class `LnStructured`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:653 in public class `LnStructured`:
        D400: First line should end with a period (not 'r')
torch\nn\utils\prune.py:669 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:677 in public method `compute_mask`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:677 in public method `compute_mask`:
        D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
torch\nn\utils\prune.py:747 in public method `apply`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:747 in public method `apply`:
        D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:747 in public method `apply`:
        D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:779 in public class `CustomFromMask`:
        D101: Missing docstring in public class
torch\nn\utils\prune.py:783 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:786 in public method `compute_mask`:
        D102: Missing docstring in public method
torch\nn\utils\prune.py:793 in public method `apply`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:793 in public method `apply`:
        D400: First line should end with a period (not 'd')
torch\nn\utils\prune.py:793 in public method `apply`:
        D401: First line should be in imperative mood (perhaps 'Add', not 'Adds')
torch\nn\utils\prune.py:806 in public function `identity`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:806 in public function `identity`:
        D400: First line should end with a period (not 'e')
torch\nn\utils\prune.py:806 in public function `identity`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
torch\nn\utils\prune.py:839 in public function `random_unstructured`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:839 in public function `random_unstructured`:
        D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:874 in public function `l1_unstructured`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:874 in public function `l1_unstructured`:
        D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:916 in public function `random_structured`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:916 in public function `random_structured`:
        D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:955 in public function `ln_structured`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:955 in public function `ln_structured`:
        D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:1000 in public function `global_unstructured`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1000 in public function `global_unstructured`:
        D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:1120 in public function `custom_from_mask`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1120 in public function `custom_from_mask`:
        D400: First line should end with a period (not '`')
torch\nn\utils\prune.py:1154 in public function `remove`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1154 in public function `remove`:
        D400: First line should end with a period (not 'e')
torch\nn\utils\prune.py:1154 in public function `remove`:
        D401: First line should be in imperative mood (perhaps 'Remove', not 'Removes')
torch\nn\utils\prune.py:1184 in public function `is_pruned`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1184 in public function `is_pruned`:
        D400: First line should end with a period (not 'r')
torch\nn\utils\prune.py:1211 in private function `_validate_pruning_amount_init`:
        D401: First line should be in imperative mood (perhaps 'Validate', not 'Validation')
torch\nn\utils\prune.py:1243 in private function `_validate_pruning_amount`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1243 in private function `_validate_pruning_amount`:
        D400: First line should end with a period (not 'e')
torch\nn\utils\prune.py:1243 in private function `_validate_pruning_amount`:
        D401: First line should be in imperative mood (perhaps 'Validate', not 'Validation')
torch\nn\utils\prune.py:1265 in private function `_validate_structured_pruning`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1265 in private function `_validate_structured_pruning`:
        D400: First line should end with a period (not '-')
torch\nn\utils\prune.py:1265 in private function `_validate_structured_pruning`:
        D401: First line should be in imperative mood (perhaps 'Validate', not 'Validation')
torch\nn\utils\prune.py:1284 in private function `_compute_nparams_toprune`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1284 in private function `_compute_nparams_toprune`:
        D400: First line should end with a period (not 'a')
torch\nn\utils\prune.py:1308 in private function `_validate_pruning_dim`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1308 in private function `_validate_pruning_dim`:
        D400: First line should end with a period (not ':')
torch\nn\utils\prune.py:1318 in private function `_compute_norm`:
        D205: 1 blank line required between summary line and description (found 0)
torch\nn\utils\prune.py:1318 in private function `_compute_norm`:
        D400: First line should end with a period (not 'n')
100
```
**After - 14**
```
torch\nn\utils\prune.py:266 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:299 in public method `__len__`:
        D105: Missing docstring in magic method
torch\nn\utils\prune.py:302 in public method `__iter__`:
        D105: Missing docstring in magic method
torch\nn\utils\prune.py:305 in public method `__getitem__`:
        D105: Missing docstring in magic method
torch\nn\utils\prune.py:411 in public method `compute_mask`:
        D102: Missing docstring in public method
torch\nn\utils\prune.py:445 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:450 in public method `compute_mask`:
        D102: Missing docstring in public method
torch\nn\utils\prune.py:502 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:507 in public method `compute_mask`:
        D102: Missing docstring in public method
torch\nn\utils\prune.py:570 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:677 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:790 in public class `CustomFromMask`:
        D101: Missing docstring in public class
torch\nn\utils\prune.py:794 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\prune.py:797 in public method `compute_mask`:
        D102: Missing docstring in public method
14
```
- torch/nn/utils/weight_norm.py
**Before - 10**
```
torch\nn\utils\weight_norm.py:1 at module level:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch\nn\utils\weight_norm.py:1 at module level:
        D400: First line should end with a period (not '8')
torch\nn\utils\weight_norm.py:12 in public class `WeightNorm`:
        D101: Missing docstring in public class
torch\nn\utils\weight_norm.py:16 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\weight_norm.py:23 in public method `compute_weight`:
        D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:29 in public method `apply`:
        D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:59 in public method `remove`:
        D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:66 in public method `__call__`:
        D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:73 in public function `weight_norm`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
torch\nn\utils\weight_norm.py:137 in public function `remove_weight_norm`:
        D401: First line should be in imperative mood (perhaps 'Remove', not 'Removes')
10
```
**After - 6**
```
torch\nn\utils\weight_norm.py:10 in public class `WeightNorm`:
        D101: Missing docstring in public class
torch\nn\utils\weight_norm.py:14 in public method `__init__`:
        D107: Missing docstring in __init__
torch\nn\utils\weight_norm.py:21 in public method `compute_weight`:
        D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:27 in public method `apply`:
        D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:57 in public method `remove`:
        D102: Missing docstring in public method
torch\nn\utils\weight_norm.py:64 in public method `__call__`:
        D102: Missing docstring in public method
6
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113021
Approved by: https://github.com/lezcano
2023-11-06 17:24:32 +00:00
Aaron Gokaslan
6d43c89f37 [BE]: Update Ruff to 0.0.280 (#105724)
Removes unusued loop values in python dictionary iteration. Automated fix from Ruff master

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105724
Approved by: https://github.com/ezyang, https://github.com/janeyx99
2023-07-22 23:03:34 +00:00
Justin Chu
4cc1745b13 [BE] f-stringify torch/ and scripts (#105538)
This PR is a follow up on the pyupgrade series to convert more strings to use f-strings using `flynt`.

- https://docs.python.org/3/reference/lexical_analysis.html#f-strings
- https://pypi.org/project/flynt/

Command used:

```
flynt torch/ -ll 120
flynt scripts/ -ll 120
flynt tools/ -ll 120
```

and excluded `collect_env.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105538
Approved by: https://github.com/ezyang, https://github.com/malfet
2023-07-21 19:35:24 +00:00
Edward Z. Yang
ba962fefea Add parametrization version of weight_norm (#103001)
This done in the ordinary way, but also:

* Deprecation warning for the old API, and a migration guide
* Backwards compatibility for state_dict loading the old weight_norm
* Test for pickling and deepcopy, which was the motivating reason

weight_norm is still used by HuggingFace Wav2Vec2.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103001
Approved by: https://github.com/albanD
2023-06-06 13:14:43 +00:00
Aaron Gokaslan
8fce9a09cd [BE]: pyupgrade Python to 3.8 - imports and object inheritance only (#94308)
Apply parts of pyupgrade to torch (starting with the safest changes).
This PR only does two things: removes the need to inherit from object and removes unused future imports.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-07 21:10:56 +00:00
PyTorch MergeBot
9db3c517de Add __all__ for torch.nn.modules, torch.distributed.elastic, torch.nn.utils submodules (#80240)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80240
Approved by: https://github.com/rohan-varma
2022-06-27 17:11:12 +00:00
Emilio Castillo
d38a71d579 torch.nn.modules.LazyModuleMixin and torch.nn.LazyLinear (Shape Inference II) (#44538)
Summary:
Retake on https://github.com/pytorch/pytorch/issues/40493 after all the feedback from albanD

This PR implements the generic Lazy mechanism and a sample `LazyLinear` layer with the `UninitializedParameter`.

The main differences with the previous PR are two;
Now `torch.nn.Module` remains untouched.
We don't require an explicit initialization or a dummy forward pass before starting the training or inference of the actual module. Making this much simpler to use from the user side.

As we discussed offline, there was the suggestion of not using a mixin, but changing the `__class__` attribute of `LazyLinear` to become `Linear` once it's completely initialized. While this can be useful, by the time being we need `LazyLinear` to be a `torch.nn.Module` subclass since there are many checks that rely on the modules being instances of `torch.nn.Module`.
This can cause problems when we create complex modules such as
```
class MyNetwork(torch.nn.Module):
    def __init__(self):
        super(MyNetwork, self).__init__()
        self.conv = torch.nn.Conv2d(20, 4, 2)
        self.linear = torch.nn.LazyLinear(10)
    def forward(self, x):
        y = self.conv(x).clamp(min=0)
        return self.linear(y)
```
Here, when the __setattr__ function is called at the time LazyLinear is registered, it won't be added to the child modules of `MyNetwork`, so we have to manually do it later, but currently there is no way to do such thing as we can't access the parent module from LazyLinear once it becomes the Linear module. (We can add a workaround to this if needed).

TODO:

Add convolutions once the design is OK
Fix docstrings

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44538

Reviewed By: ngimel

Differential Revision: D24162854

Pulled By: albanD

fbshipit-source-id: 6d58dfe5d43bfb05b6ee506e266db3cf4b885f0c
2020-10-19 13:13:54 -07:00
Nikita Shulga
1c6ace87d1 Embed torch.nn typing annotations (#43044)
Summary:
Delete several .pyi files and embed annotations from those files in respective .py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43044

Reviewed By: ezyang

Differential Revision: D23123234

Pulled By: malfet

fbshipit-source-id: 4ba361cc84402352090523924b0035e100ba48b1
2020-08-14 13:24:58 -07:00
Michela Paganini
d37a4861b8 Explicit attribute setting for pruning and weight_norm upon reparam removal (#34170)
Summary:
To address one of the problems with RNNs that emerged in https://github.com/pytorch/pytorch/issues/33618, I modified the `remove` methods in `torch.nn.utils.prune` and `torch.nn.utils.weight_norm` to make an explicit call to `setattr`, which, in `rnn.py` directly modifies `_flat_weights` (https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py#L96) to include the new element.

This is important so that `_flat_weights` can reflect the presence of the `Parameter` after the (pruning or weight norm) reparametrization is removed. Without this, the weight in `_flat_weights` would remain a tensor, as originally set by the reparametrization.

Simple testing is added, which depends on the current naming scheme for the LSTM module.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34170

Differential Revision: D21265965

Pulled By: mickypaganini

fbshipit-source-id: 29de4a6b17052d42ccfe67c8560b7f83c20fd09d
2020-04-29 09:01:59 -07:00
ZhuBaohe
19a6de328f Correct docstring of vision/init functions
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17351

Differential Revision: D14276355

Pulled By: soumith

fbshipit-source-id: 9b572b6a04eeb1e44cd93961edac76ed10f7b24e
2019-03-01 11:40:23 -08:00
Tongzhou Wang
2cd912bcc2 Fix more spectral norm bugs (#13350)
Summary:
Problems with SN and DP after #12671 :
1. in eval mode, `weight_orig` is not getting correct gradient #12737 .

    Fix: keep `v` vector around as a buffer and always calculate `W = W_orig / (u @ W_orig @ v)` even in eval.

2. in training mode, the `weight` buffer of the parallelized module is never updated, if someone touches `weight_orig` and/or `weight` and makes them not sharing storage. So in `eval` the weight used is wrong.

    Fix: Make `weight` not a buffer anymore and always calculate it as above.

3. #12671 changed SN to update `u` in-place to make DP work correctly, but then it breaks backward through two forwards (e.g., the common GAN loss `D(real) - D(fake)`) because the vectors needed to backprop the 1st forward is changed in the 2nd forward.

    Fix: This PR clones `u` and `v` before using them.

To maintain BC, I added a hook interface for producing and loading state_dict. This is ugly and we should really have better interface for spectral_norm. But for the purpose to fix this issue, I make this patch. Even if we have a better interface, BC mechanism for legacy loading legacy state_dict still needs to be done.

cc The controller you requested could not be found. crcrpar
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13350

Differential Revision: D12931044

Pulled By: SsnL

fbshipit-source-id: 8be6f934eaa62414d76d2c644dedd7e1b7eb31ef
2018-11-06 19:16:13 -08:00
Edward Yang
74197c7115 Restore support for dim=None on WeightNorm. (#11661)
Summary:
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11661

Reviewed By: veenix

Differential Revision: D9826799

Pulled By: ezyang

fbshipit-source-id: 9eec57bb27a365406669e412f6eb88741b22ed3d
2018-09-14 07:39:43 -07:00
Michael Carilli
a3036b3bb3 Fused weightnorm for ATen (#10842)
Summary:
This PR contains a C++ implementation of weight norm.  The user-side exposure of weight norm through torch.nn.utils.weight_norm is unchanged.

If running on the GPU, and the norm is requested over the first or last dimension of the weight tensor, the forward pass is carried out using the fused kernels I wrote for our Fairseq GTC hero run, which offer superior performance to primitive ops and superior numerical stability when running in FP16.  In the common case that the backward pass is not itself constructing a graph (ie not attempting to set up double backward) the backward pass will be carried out using another fused kernel.  If the backward pass is constructing a graph, an alternate code path is taken, which does the math using differentiable primitive ops. In this way, the implementation allows double backward, even if the fused kernel was used in forward (although in this case, you don't benefit from the performance and stability of the fused backward kernel).

If running on the CPU, or if norming over an interior dim, the forward pass is carried out using double-differentiable primitive ops.

Figuring out how to generate all the right plumbing for this was tricky, but it was a fun experience learning how the autogenerator works and how the graph is constructed.  Thanks to colesbury for useful guidance on this front.

I do have a few lingering questions:

- Should I unify my return statements (ie by default-constructing Tensors outside if blocks and using operator= within)?
- What is the significance of `non_blocking` when calling e.g. `auto norms = saved_norms.to(saved_g.type().scalarType(), non_blocking=True/False);`?  I am currently omitting `non_blocking`, so it defaults to False, but I didn't see any associated synchronizes on the timeline, so I'm wondering what it means.
- Is there an "official" mapping from at::ScalarTypes to corresponding accumulate types, as there are for the PODs + Half in [AccumulateType.h](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/AccumulateType.h)?  I looked for an equivalent mapping for ScalarTypes, didn't find one, and ended up rigging it myself (`  at::ScalarType AccType = g.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float : g.type().scalarType();`).
- Are sparse tensors a concern?  Should I include another check for sparse tensors in the `_weight_norm` entry point, and send those along the fallback CPU path as well?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10842

Differential Revision: D9735531

Pulled By: ezyang

fbshipit-source-id: 24431d46532cf5503876b3bd450d5ca775b3eaee
2018-09-12 13:55:27 -07:00
Tongzhou Wang
1c01eabd3c
Codemod to update our codebase to 0.4 standard (#6641)
* Codemod to update our codebase to 0.4 standard

* Update some of the test scri[ts

* remove Variable in test_clip_grad_value

* fix _symbolic_override_wrapper_maker
2018-04-17 22:06:54 -04:00
SsnL
de1f4e69dd raw text (#3327) 2017-10-28 01:24:02 +05:30
Sam Gross
661beb3345 Speed-up weight_norm over the right-most dim (#2431)
When weight-normalizing over the right-most dimension, combine all
dimensions to the left into a single dim. This avoids two extra
transposes.
2017-08-16 18:04:18 -04:00
Sam Gross
ea563c1df1 Make weight norm pickleable (#2066) 2017-07-12 17:21:22 -04:00
Sam Gross
2c038f2074 Add weight normalization implementation (#1945)
* Add weight normalization implementation

This adds forward "pre-hooks" which get called before the module's
forward() method. Weight norm is implemented as a hook which calculates
the weight variable from the weight_g and weight_v every iteration.

Based on @rtqichen implementation.

* Specify return type
2017-06-30 15:41:40 -04:00