Commit Graph

699 Commits

Author SHA1 Message Date
pilot-j
9062e429db Fixed docstring errors in torch/nn/functional.py (Docathon H2) (#112856)
Fixes #112597
### Output:
**BEFORE:**
```
functional.py:1 at module level:
        D400: First line should end with a period (not 'e')
functional.py:438 in public function `fractional_max_pool2d_with_indices`:
        D400: First line should end with a period (not ')')
functional.py:537 in public function `fractional_max_pool3d_with_indices`:
        D400: First line should end with a period (not ')')
functional.py:646 in public function `max_pool1d_with_indices`:
        D400: First line should end with a period (not ')')
functional.py:732 in public function `max_pool2d_with_indices`:
        D400: First line should end with a period (not ')')
functional.py:818 in public function `max_pool3d_with_indices`:
        D400: First line should end with a period (not ')')
functional.py:932 in public function `max_unpool1d`:
        D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
functional.py:968 in public function `max_unpool2d`:
        D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
functional.py:1000 in public function `max_unpool3d`:
        D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
functional.py:1031 in public function `lp_pool2d`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:1031 in public function `lp_pool2d`:
        D400: First line should end with a period (not 'f')
functional.py:1031 in public function `lp_pool2d`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:1056 in public function `lp_pool1d`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:1056 in public function `lp_pool1d`:
        D400: First line should end with a period (not 'f')
functional.py:1056 in public function `lp_pool1d`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:1077 in public function `adaptive_max_pool1d_with_indices`:
        D400: First line should end with a period (not ')')
functional.py:1119 in public function `adaptive_max_pool2d_with_indices`:
        D400: First line should end with a period (not ')')
functional.py:1163 in public function `adaptive_max_pool3d_with_indices`:
        D400: First line should end with a period (not ')')
functional.py:1220 in public function `adaptive_avg_pool2d`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:1220 in public function `adaptive_avg_pool2d`:
        D400: First line should end with a period (not 'f')
functional.py:1220 in public function `adaptive_avg_pool2d`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:1237 in public function `adaptive_avg_pool3d`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:1237 in public function `adaptive_avg_pool3d`:
        D400: First line should end with a period (not 'f')
functional.py:1237 in public function `adaptive_avg_pool3d`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:1255 in public function `dropout`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:1255 in public function `dropout`:
        D400: First line should end with a period (not 't')
functional.py:1275 in public function `alpha_dropout`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:1287 in public function `dropout1d`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:1287 in public function `dropout1d`:
        D400: First line should end with a period (not ',')
functional.py:1325 in public function `dropout2d`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:1325 in public function `dropout2d`:
        D400: First line should end with a period (not ',')
functional.py:1369 in public function `dropout3d`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:1369 in public function `dropout3d`:
        D400: First line should end with a period (not ',')
functional.py:1408 in public function `feature_alpha_dropout`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:1408 in public function `feature_alpha_dropout`:
        D400: First line should end with a period (not ',')
functional.py:1466 in public function `relu`:
        D400: First line should end with a period (not 'r')
functional.py:1466 in public function `relu`:
        D402: First line should not be the function's "signature"
functional.py:1491 in public function `glu`:
        D400: First line should end with a period (not 'r')
functional.py:1491 in public function `glu`:
        D402: First line should not be the function's "signature"
functional.py:1516 in public function `hardtanh`:
        D400: First line should end with a period (not 'r')
functional.py:1516 in public function `hardtanh`:
        D402: First line should not be the function's "signature"
functional.py:1542 in public function `relu6`:
        D400: First line should end with a period (not 'r')
functional.py:1542 in public function `relu6`:
        D402: First line should not be the function's "signature"
functional.py:1558 in public function `elu`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:1582 in public function `selu`:
        D400: First line should end with a period (not 'r')
functional.py:1582 in public function `selu`:
        D402: First line should not be the function's "signature"
functional.py:1611 in public function `celu`:
        D400: First line should end with a period (not 'r')
functional.py:1611 in public function `celu`:
        D402: First line should not be the function's "signature"
functional.py:1638 in public function `leaky_relu`:
        D400: First line should end with a period (not 'r')
functional.py:1638 in public function `leaky_relu`:
        D402: First line should not be the function's "signature"
functional.py:1688 in public function `rrelu`:
        D400: First line should end with a period (not 'r')
functional.py:1688 in public function `rrelu`:
        D402: First line should not be the function's "signature"
functional.py:1755 in public function `tanhshrink`:
        D400: First line should end with a period (not 'r')
functional.py:1755 in public function `tanhshrink`:
        D402: First line should not be the function's "signature"
functional.py:1767 in public function `softsign`:
        D400: First line should end with a period (not 'r')
functional.py:1767 in public function `softsign`:
        D402: First line should not be the function's "signature"
functional.py:1806 in public function `softmin`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:1832 in public function `softmax`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:1868 in public function `gumbel_softmax`:
        D401: First line should be in imperative mood (perhaps 'Sample', not 'Samples')
functional.py:1930 in public function `log_softmax`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:1969 in public function `tanh`:
        D400: First line should end with a period (not 'r')
functional.py:1969 in public function `tanh`:
        D402: First line should not be the function's "signature"
functional.py:1980 in public function `sigmoid`:
        D400: First line should end with a period (not 'r')
functional.py:1980 in public function `sigmoid`:
        D402: First line should not be the function's "signature"
functional.py:1990 in public function `hardsigmoid`:
        D400: First line should end with a period (not 'n')
functional.py:1990 in public function `hardsigmoid`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:2057 in public function `silu`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:2057 in public function `silu`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:2081 in public function `mish`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:2081 in public function `mish`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:2100 in public function `hardswish`:
        D400: First line should end with a period (not ':')
functional.py:2100 in public function `hardswish`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:2136 in public function `embedding`:
        D202: No blank lines allowed after function docstring (found 1)
functional.py:2136 in public function `embedding`:
        D401: First line should be in imperative mood; try rephrasing (found 'A')
functional.py:2254 in public function `embedding_bag`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:2254 in public function `embedding_bag`:
        D400: First line should end with a period (not 'e')
functional.py:2254 in public function `embedding_bag`:
        D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
functional.py:2462 in public function `batch_norm`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:2507 in public function `instance_norm`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:2507 in public function `instance_norm`:
        D400: First line should end with a period (not 'a')
functional.py:2507 in public function `instance_norm`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:2540 in public function `layer_norm`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:2554 in public function `group_norm`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:2567 in public function `local_response_norm`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:2567 in public function `local_response_norm`:
        D400: First line should end with a period (not 'f')
functional.py:2567 in public function `local_response_norm`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:2611 in public function `ctc_loss`:
        D401: First line should be in imperative mood; try rephrasing (found 'The')
functional.py:2679 in public function `nll_loss`:
        D401: First line should be in imperative mood; try rephrasing (found 'The')
functional.py:2895 in public function `kl_div`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:2895 in public function `kl_div`:
        D400: First line should end with a period (not 's')
functional.py:2895 in public function `kl_div`:
        D401: First line should be in imperative mood; try rephrasing (found 'The')
functional.py:2978 in public function `cross_entropy`:
        D401: First line should be in imperative mood; try rephrasing (found 'This')
functional.py:3069 in public function `binary_cross_entropy`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:3069 in public function `binary_cross_entropy`:
        D400: First line should end with a period (not 't')
functional.py:3069 in public function `binary_cross_entropy`:
        D401: First line should be in imperative mood; try rephrasing (found 'Function')
functional.py:3139 in public function `binary_cross_entropy_with_logits`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:3139 in public function `binary_cross_entropy_with_logits`:
        D400: First line should end with a period (not 't')
functional.py:3139 in public function `binary_cross_entropy_with_logits`:
        D401: First line should be in imperative mood; try rephrasing (found 'Function')
functional.py:3211 in public function `smooth_l1_loss`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:3211 in public function `smooth_l1_loss`:
        D400: First line should end with a period (not 'e')
functional.py:3211 in public function `smooth_l1_loss`:
        D401: First line should be in imperative mood; try rephrasing (found 'Function')
functional.py:3251 in public function `huber_loss`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:3251 in public function `huber_loss`:
        D400: First line should end with a period (not 'e')
functional.py:3251 in public function `huber_loss`:
        D401: First line should be in imperative mood; try rephrasing (found 'Function')
functional.py:3282 in public function `l1_loss`:
        D400: First line should end with a period (not 'r')
functional.py:3282 in public function `l1_loss`:
        D402: First line should not be the function's "signature"
functional.py:3313 in public function `mse_loss`:
        D400: First line should end with a period (not 'r')
functional.py:3313 in public function `mse_loss`:
        D402: First line should not be the function's "signature"
functional.py:3346 in public function `margin_ranking_loss`:
        D400: First line should end with a period (not 'r')
functional.py:3346 in public function `margin_ranking_loss`:
        D402: First line should not be the function's "signature"
functional.py:3382 in public function `hinge_embedding_loss`:
        D400: First line should end with a period (not 'r')
functional.py:3382 in public function `hinge_embedding_loss`:
        D402: First line should not be the function's "signature"
functional.py:3411 in public function `multilabel_margin_loss`:
        D400: First line should end with a period (not 'r')
functional.py:3411 in public function `multilabel_margin_loss`:
        D402: First line should not be the function's "signature"
functional.py:3439 in public function `soft_margin_loss`:
        D400: First line should end with a period (not 'r')
functional.py:3439 in public function `soft_margin_loss`:
        D402: First line should not be the function's "signature"
functional.py:3462 in public function `multilabel_soft_margin_loss`:
        D400: First line should end with a period (not 'r')
functional.py:3462 in public function `multilabel_soft_margin_loss`:
        D402: First line should not be the function's "signature"
functional.py:3510 in public function `cosine_embedding_loss`:
        D400: First line should end with a period (not 'r')
functional.py:3510 in public function `cosine_embedding_loss`:
        D402: First line should not be the function's "signature"
functional.py:3543 in public function `multi_margin_loss`:
        D400: First line should end with a period (not 'r')
functional.py:3543 in public function `multi_margin_loss`:
        D402: First line should not be the function's "signature"
functional.py:3708 in public function `upsample` (skipping F811,B950):
        D103: Missing docstring in public function
functional.py:3713 in public function `upsample` (skipping F811,B950):
        D103: Missing docstring in public function
functional.py:3718 in public function `upsample` (skipping F811):
        D205: 1 blank line required between summary line and description (found 0)
functional.py:3718 in public function `upsample` (skipping F811):
        D400: First line should end with a period (not 'n')
functional.py:3783 in private function `_is_integer`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:3794 in public function `interpolate` (skipping F811,B950):
        D103: Missing docstring in public function
functional.py:3799 in public function `interpolate` (skipping F811,B950):
        D103: Missing docstring in public function
functional.py:3804 in public function `interpolate` (skipping F811,B950):
        D103: Missing docstring in public function
functional.py:3809 in public function `interpolate` (skipping F811):
        D103: Missing docstring in public function
functional.py:3821 in public function `interpolate` (skipping F811,B950):
        D205: 1 blank line required between summary line and description (found 0)
functional.py:3821 in public function `interpolate` (skipping F811,B950):
        D400: First line should end with a period (not 'n')
functional.py:4062 in public function `upsample_nearest` (skipping F811):
        D103: Missing docstring in public function
functional.py:4067 in public function `upsample_nearest` (skipping F811):
        D103: Missing docstring in public function
functional.py:4100 in public function `upsample_bilinear` (skipping F811):
        D103: Missing docstring in public function
functional.py:4107 in public function `upsample_bilinear` (skipping F811):
        D103: Missing docstring in public function
functional.py:4114 in public function `upsample_bilinear` (skipping F811):
        D103: Missing docstring in public function
functional.py:4121 in public function `upsample_bilinear` (skipping F811):
        D103: Missing docstring in public function
functional.py:4174 in public function `grid_sample`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:4174 in public function `grid_sample`:
        D400: First line should end with a period (not 'e')
functional.py:4315 in public function `affine_grid`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:4315 in public function `affine_grid`:
        D400: First line should end with a period (not 'f')
functional.py:4315 in public function `affine_grid`:
        D401: First line should be in imperative mood (perhaps 'Generate', not 'Generates')
functional.py:4608 in public function `triplet_margin_loss`:
        D200: One-line docstring should fit on one line with quotes (found 3)
functional.py:4608 in public function `triplet_margin_loss`:
        D400: First line should end with a period (not 's')
functional.py:4643 in public function `triplet_margin_with_distance_loss`:
        D200: One-line docstring should fit on one line with quotes (found 3)
functional.py:4705 in public function `normalize`:
        D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
functional.py:4733 in public function `assert_int_or_pair`:
        D103: Missing docstring in public function
functional.py:4743 in public function `unfold`:
        D401: First line should be in imperative mood (perhaps 'Extract', not 'Extracts')
functional.py:4773 in public function `fold`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:4773 in public function `fold`:
        D400: First line should end with a period (not 'g')
functional.py:4773 in public function `fold`:
        D401: First line should be in imperative mood (perhaps 'Combine', not 'Combines')
functional.py:4800 in private function `_in_projection_packed`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:4800 in private function `_in_projection_packed`:
        D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
functional.py:4867 in private function `_in_projection`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:4867 in private function `_in_projection`:
        D400: First line should end with a period (not 'y')
functional.py:4867 in private function `_in_projection`:
        D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
functional.py:5128 in public function `multi_head_attention_forward`:
        D205: 1 blank line required between summary line and description (found 0)
functional.py:5128 in public function `multi_head_attention_forward`:
        D400: First line should end with a period (not ':')
160
```

**AFTER:**

```
functional.py:3709 in public function `upsample` (skipping F811,B950):
        D103: Missing docstring in public function
functional.py:3714 in public function `upsample` (skipping F811,B950):
        D103: Missing docstring in public function
functional.py:3798 in public function `interpolate` (skipping F811,B950):
        D103: Missing docstring in public function
functional.py:3803 in public function `interpolate` (skipping F811,B950):
        D103: Missing docstring in public function
functional.py:3808 in public function `interpolate` (skipping F811,B950):
        D103: Missing docstring in public function
functional.py:3813 in public function `interpolate` (skipping F811):
        D103: Missing docstring in public function
functional.py:4068 in public function `upsample_nearest` (skipping F811):
        D103: Missing docstring in public function
functional.py:4073 in public function `upsample_nearest` (skipping F811):
        D103: Missing docstring in public function
functional.py:4106 in public function `upsample_bilinear` (skipping F811):
        D103: Missing docstring in public function
functional.py:4113 in public function `upsample_bilinear` (skipping F811):
        D103: Missing docstring in public function
functional.py:4120 in public function `upsample_bilinear` (skipping F811):
        D103: Missing docstring in public function
functional.py:4127 in public function `upsample_bilinear` (skipping F811):
        D103: Missing docstring in public function
functional.py:4742 in public function `assert_int_or_pair`:
        D103: Missing docstring in public function
13
```

The file contained several docstring errors. I have (hopefully) fixed all of them and tried to improve the overall readability of the code. For the most part, I have included a relevant description of each function (referred from the official PyTorch docs). In cases where a function is purely mathematical or a one-line description is hard to give, I have just included references.
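
For illustration, here is the shape of a typical D401 fix, using `max_unpool1d` from the list above (a representative sketch with a trimmed signature, not the exact diff); the D400 cases are analogous, adding the missing trailing period:

```python
# Before: D401, the summary line is not in the imperative mood
def max_unpool1d(input, indices, kernel_size):
    r"""Computes a partial inverse of :class:`MaxPool1d`."""


# After: imperative mood, following the tool's suggestion ('Compute', not 'Computes')
def max_unpool1d(input, indices, kernel_size):
    r"""Compute a partial inverse of :class:`MaxPool1d`."""
```

Assuming the listings above come from `pydocstyle` (their format matches), the check can be rerun locally with `pydocstyle functional.py`.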

For testing, I relied on my local system and created a separate file. For the final edits, I directly changed the contents of the forked repo, as already visible.

Kindly review @svekars @subramen @kit1980

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112856
Approved by: https://github.com/kit1980
2023-11-13 22:16:49 +00:00
giacomo
7b28f8c5ea Better error message when applying interpolation on non-4D tensors (#113459)
Fixes #113445

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113459
Approved by: https://github.com/albanD
2023-11-10 21:06:51 +00:00
Eric Zhang
468a73f0e3 Support Numpy ints in the torch.nn.functional.interpolate dtype check (#110778)
In https://github.com/pytorch/pytorch/pull/99243, a check was added to ensure the `size` only contained integers.

This PR updates the check to also accept numpy integers, based on this comment (cc @kit1980): https://github.com/pytorch/pytorch/pull/99243#issuecomment-1646736646. Like the other commenter, I also ran into issues where existing software broke after upgrading to PT 2.1:

```
                if not torch.jit.is_scripting():
                    if not all(_is_integer(x) for x in size):
>                       raise TypeError(
                            "expected size to be one of int or Tuple[int] or Tuple[int, int] or "
                            f"Tuple[int, int, int], but got size with types {[type(x) for x in size]}"
                        )
E                       TypeError: expected size to be one of int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], but got size with types [<class 'numpy.int64'>, <class 'numpy.int64'>]

/conda-env/lib/python3.8/site-packages/torch/nn/functional.py:3924: TypeError
```
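
A minimal sketch of a relaxed check (hypothetical code mirroring the `_is_integer` helper visible in the traceback; the actual diff may differ):

```python
import numbers

import torch


def _is_integer(x) -> bool:
    # numpy integer scalars register with numbers.Integral, so this accepts
    # Python ints and numpy ints alike without importing numpy directly.
    if isinstance(x, numbers.Integral):
        return True
    # Also allow integer-dtype scalar tensors.
    return isinstance(x, torch.Tensor) and not torch.is_floating_point(x)
```
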
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110778
Approved by: https://github.com/mikaylagawarecki
2023-10-10 01:46:33 +00:00
Mikayla Gawarecki
abd83ce180 Small fix in SDPA docstring codeblock (#109086)
Fix https://github.com/pytorch/pytorch/issues/109072

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109086
Approved by: https://github.com/drisspg
2023-09-12 16:48:46 +00:00
FFFrog
969bf8a054 Fix the document of torch.nn.functional.conv2d (#107851)
Fixes #107692

Fix the documentation of torch.nn.functional.conv2d.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107851
Approved by: https://github.com/mikaylagawarecki
2023-08-24 18:02:03 +00:00
Aaron Gokaslan
660e8060ad [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.

I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, it seems there are no instances of it in our codebase, so I'm enabling it so that it stays that way. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-22 23:16:38 +00:00
PyTorch MergeBot
d59a6864fb Revert "[BE]: Update ruff to 0.285 (#107519)"
This reverts commit 88ab3e4322.

Reverted https://github.com/pytorch/pytorch/pull/107519 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR breaks internal tests. @ezyang, can you please help them get unblocked? It seems like one of the strings was probably accidentally modified ([comment](https://github.com/pytorch/pytorch/pull/107519#issuecomment-1688833480))
2023-08-22 19:53:32 +00:00
Aaron Gokaslan
88ab3e4322 [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.

I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, it seems there are no instances of it in our codebase, so I'm enabling it so that it stays that way. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-20 01:36:18 +00:00
Mikayla Gawarecki
1317dbf176 Reland "Add nn.CircularPad{*}d for consistency + fix no_batch_dim support (#106148)" (#106632)
The previous one was reverted because the PR stacked under it, which added error checking to the Pad variants (https://github.com/pytorch/pytorch/pull/106147), was reverted: internally, some people pass 2D inputs to ZeroPad2d (which should actually take 3D or 4D inputs :)). To my understanding, though, there wasn't actually anything this PR itself was breaking.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106632
Approved by: https://github.com/albanD
2023-08-07 20:10:25 +00:00
PyTorch MergeBot
dfcfd5cedb Revert "Add nn.CircularPad{*}d for consistency + fix no_batch_dim support (#106148)"
This reverts commit 87d2536971.

Reverted https://github.com/pytorch/pytorch/pull/106148 on behalf of https://github.com/malfet due to Reverting as dependent PR https://github.com/pytorch/pytorch/pull/106147 was reverted as well ([comment](https://github.com/pytorch/pytorch/pull/106148#issuecomment-1662344543))
2023-08-02 14:46:00 +00:00
Mikayla Gawarecki
87d2536971 Add nn.CircularPad{*}d for consistency + fix no_batch_dim support (#106148)
Fixes #105749 and https://github.com/pytorch/pytorch/issues/95320

(tl;dr: input should always be `[N, C, H (, W, D)]`, where only the H, W, and D dimensions get circular padding, so in the 2D case where the user wants both dimensions padded, they should `.unsqueeze(0)` first (as is the case for `Reflection/ReplicationPad`), but we didn't document this for circular padding; a short usage example follows the list below. [This seems to be the old docstring](277b05014a/torch/nn/functional.py (L4689)) that was somehow lost.)

Fixes no_batch_dim support (https://github.com/pytorch/pytorch/issues/104860)

- Adds missing documentation for circular padding
- Adds missing CircularPad modules
- Migrates legacy test_nn tests from circular padding to ModuleInfo
- Adds no_batch_dim support + sample inputs that test this
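
A short usage example of the convention from the tl;dr above (a sketch; shapes are illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.randn(5, 7)
# Circular padding only wraps the trailing spatial dims, so a bare 2D tensor
# needs explicit leading dims first, as with Reflection/ReplicationPad:
padded = F.pad(x[None, None], (1, 1, 1, 1), mode="circular")[0, 0]
print(padded.shape)  # torch.Size([7, 9])
```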

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106148
Approved by: https://github.com/albanD
ghstack dependencies: #106325, #106147
2023-08-01 12:49:58 +00:00
FFFrog
9a1cdcb8a0 Format: fixing multiple string concatenation in single line (#106013)
Fixing multiple string concatenations on a single line.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106013
Approved by: https://github.com/albanD
2023-07-26 18:39:18 +00:00
lezcano
9bde7f4e27 Fix the docs for cosine_similarity (#104772)
The behaviour of `cosine_similarity` was subtly changed in
https://github.com/pytorch/pytorch/pull/31378, but the docs were not
updated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104772
Approved by: https://github.com/albanD, https://github.com/svekars
2023-07-26 09:23:09 +00:00
Justin Chu
4cc1745b13 [BE] f-stringify torch/ and scripts (#105538)
This PR is a follow-up in the pyupgrade series to convert more strings to f-strings using `flynt`.

- https://docs.python.org/3/reference/lexical_analysis.html#f-strings
- https://pypi.org/project/flynt/

Command used:

```
flynt torch/ -ll 120
flynt scripts/ -ll 120
flynt tools/ -ll 120
```

and excluded `collect_env.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105538
Approved by: https://github.com/ezyang, https://github.com/malfet
2023-07-21 19:35:24 +00:00
Justin Chu
79c5e33349 [BE] Enable ruff's UP rules and autoformat nn/ mps/ and torch/ (#105436)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105436
Approved by: https://github.com/malfet, https://github.com/albanD
2023-07-21 07:38:46 +00:00
drisspg
2ee440054b Small tweaks to SDPA docs (#104749)
Fixes #104652

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104749
Approved by: https://github.com/mikaylagawarecki
2023-07-10 21:01:45 +00:00
yewentao
d3ba8901d8 Adding precision issue note docs for functional.interpolate (#104622)
Fixes #104157

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104622
Approved by: https://github.com/ezyang
2023-07-05 16:20:57 +00:00
vfdev
4ab140902b [docs] Fixed typo in grid_sample doctring (#104406)
Fixed a small typo in grid_sample doctring:

<img width="265" alt="image" src="https://github.com/pytorch/pytorch/assets/2459423/1d2dd7a2-895a-4683-9d9f-a4d1d9d1a4a7">

- https://pytorch.org/docs/main/generated/torch.nn.functional.grid_sample.html

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104406
Approved by: https://github.com/mikaylagawarecki, https://github.com/svekars
2023-06-29 19:44:54 +00:00
Ryan Smith
6bda97e2c1 Raise type error message for interpolate if size contains non-integer elements (#99243)
Raise a TypeError for interpolate when the output size is a tuple containing elements that are not `int`.

Fixes #98287

The check is only performed if `size` is an instance of `list` or `tuple`.
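
A hypothetical repro of the new behavior:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
# Floats in `size` now fail fast instead of erroring deep inside the op:
F.interpolate(x, size=(16.0, 16.0))
# TypeError: expected size to be one of int or Tuple[int] or Tuple[int, int] or
# Tuple[int, int, int], but got size with types [<class 'float'>, <class 'float'>]
```
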
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99243
Approved by: https://github.com/Skylion007, https://github.com/Neilblaze, https://github.com/MovsisyanM, https://github.com/albanD
2023-06-23 00:48:45 +00:00
MysticalMusings
f1f13a35b0 Fix GELU-related docstring formatting (#102845)
The docstring about GELU is formatted incorrectly. It is currently rendered as:

$$ \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt(2 / \pi) * (x + 0.044715 * x^3))) $$

where it is ambiguous which part the square root applies to.

I double-checked the formula, which should be:

$$ \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3))) $$

where the round brackets in the source code should be curly braces.

> _formula in [original paper](https://arxiv.org/abs/1606.08415)_
> ![Snipaste_2023-06-03_00-43-49](https://github.com/pytorch/pytorch/assets/39690782/22511c4e-2f20-4a16-9bda-4c182a360160)
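
A quick numerical sanity check of the corrected formula against the exact GELU (a sketch; the printed bound is illustrative):

```python
import math

import torch

x = torch.linspace(-4, 4, steps=1001)
exact = 0.5 * x * (1 + torch.erf(x / math.sqrt(2)))
approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
# The tanh form is an approximation, so agreement is close but not exact:
print((exact - approx).abs().max())  # small (well under 1e-2)
```
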
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102845
Approved by: https://github.com/mikaylagawarecki
2023-06-08 20:19:03 +00:00
cviviers
81c181dc01 Update BCEWithLogitsLoss pos_weight description in documentation (#101567)
Fixes #82496 and #65702

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101567
Approved by: https://github.com/mikaylagawarecki
2023-05-19 21:23:21 +00:00
Edward Z. Yang
c567748e16 Make interpolate_bilinear deterministic using decomposition (#101115)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101115
Approved by: https://github.com/ngimel
2023-05-11 22:48:01 +00:00
Joel Schlosser
bd9d50a3fc Remove future deprecation warning from kl_div docs (#96541)
Fixes #95687
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96541
Approved by: https://github.com/albanD
2023-05-05 23:01:21 +00:00
soulitzer
6585d76f0f [docs] nn.functional.embedding: Note expected discrepancy between numerical and analytical gradients (#99181)
Fixes https://github.com/pytorch/pytorch/issues/93950
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99181
Approved by: https://github.com/albanD
2023-04-22 02:30:53 +00:00
mega-optimus
06081ac8f3 Update docstring of torch.nn.functional.normalize() (#99512)
Fixes #99125

torch.nn.functional.normalize() already supports `dim` as a tuple of ints, but the docstring says int only.
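
For example (a small usage sketch):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 5, 6)
# dim accepts a tuple of ints, normalizing jointly over those dims:
y = F.normalize(x, p=2.0, dim=(1, 2))
print(y.flatten(1).norm(dim=1))  # ~1.0 for every sample
```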

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99512
Approved by: https://github.com/albanD
2023-04-21 16:45:24 +00:00
ts
dbf0db958f Fix torch.nn.FractionalMaxPool2d output_size error (#99507)
Fixes #99148, raising an error if output_ratio's size exceeds 2.

Justification for changes:

If an output size is not specified but an output ratio is, we call fractional_max_pool2d_with_indices. We then generate the value of output_size based on the first two integers of the output_ratio (line ~480 of torch.nn.functional.py).

Thus, we should raise a ValueError in the case that the user passes an output_ratio (instead of an output_size) and the number of elements in output_ratio exceeds two. We must raise the error before calling torch._C._nn.fractional_max_pool2d, since the value of output_size passed into it is guaranteed to be of size 2 (the existing code generates it from the first two indices of the passed-in ratio).

I would be happy to iterate on this if there are any issues.
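
A repro sketch of the guarded case:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 32, 32)
# An output_ratio with more than two elements is now rejected up front
# instead of being silently truncated to its first two entries:
F.fractional_max_pool2d(x, kernel_size=2, output_ratio=(0.5, 0.5, 0.5))
# ValueError
```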

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99507
Approved by: https://github.com/mikaylagawarecki
2023-04-21 14:38:25 +00:00
Kazuaki Ishizaki
a531a464fd Fix typos under torch/nn directory (#97594)
This PR fixes typos in comments of `.py` files under `torch/nn` directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97594
Approved by: https://github.com/dagitses, https://github.com/kit1980
2023-04-10 22:07:15 +00:00
Mikayla Gawarecki
73b06a0268 Fix rendering of arguments for nn.functional ops that use boolean_dispatch (#98092)
Fix #97982

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98092
Approved by: https://github.com/albanD
2023-04-03 21:17:43 +00:00
Aaron Gokaslan
597b558c51 [BE]: Update flake8 and plugins and fix bugs (#97795)
Update flake8 and flake8 plugins in lintrunner to modern versions. This enables more checks and makes flake8 checks significantly faster. Added a few additional rule ignores that will need to be fixed in the future.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97795
Approved by: https://github.com/alexsio27444, https://github.com/janeyx99, https://github.com/ezyang
2023-03-28 23:51:55 +00:00
Michael Gschwind
c757647dd8 [Better Transformer] make is_causal a hint and force attn_mask to be set on is_causal=True in F.MHA (#97214)
Summary:
This fixes an issue raised in [is_causal parameter in torch.nn.TransformerEncoderLayer.forward does not work #96941](https://github.com/pytorch/pytorch/issues/96941) where results computed with is_causal do not properly reflect causal masking.

In PyTorch 2.0, Accelerated PT Transformers added the is_causal parameter to legacy nn.Transformer* and nn.MHA APIs aligned with and intended to engage the is_causal parameter of the new scaled_dot_product_attention (SDPA) operator.

At present, is_causal works differently for the nn.Transformer* modules, nn.MHA, and F.MHA:
* The nn.Transformer* modules treat is_causal as an optional indicator about the format of attn_mask. This is because some layers (such as the CLIP layer) use the attention mask in the layer, and thus attn_mask was a required feature.
* Initially, nn.MHA and F.MHA were defined to align with F.SDPA in behavior: a user may specify either the attention mask or is_causal, but not both. It seemed to make sense at the time to align SDPA and MHA, especially since there was a larger overlap of parameters, which has since changed, e.g., with the removal of need_weights from SDPA. (See below for why this makes sense.)

Unfortunately, this does not work because of how MHA was changed to handle the need_weights parameter. When need_weights is present, we no longer call SDPA, because support for need_weights was removed from SDPA before the release. The rationale is that need_weights defeats all optimization at the foundation of SDPA performance. Having the flag might thus mislead users into thinking they get good performance and leave them disappointed when they enable a legacy feature of MHA that massively degrades performance. (They might not think anything of enabling it, because it is on by default in MHA today, which leads to more issues.)

Since SDPA no longer supports need_weights, we need to pick a separate path which implements attention using a set of discrete operations that allocates a tensor for weights. Alas, this code path does not support is_causal, because attention is implemented as a matmul using the attention mask. Thus, is_causal has no impact. (A substantially similar situation arises with how key_padding_mask is implemented today, because Nested Tensors are not supported by torch.compile() in 2.0.)

This problem was masked because all uses of legacy nn.MHA (and F.MHA) come through nn.Transformer*, which called self-attention (i.e., nn.MHA) only ever with the attention mask attn_mask, and never with is_causal, a missed optimization opportunity that would have been addressed in a future performance update.

Regrettably, always calling nn.MHA with attn_mask prevented diagnosing the issue of not having a suitable attention mask when need_weights support was dropped from SDPA and a discrete implementation of attention was added for that scenario, as well as for the execution path with key_padding_mask.

We have two options to address this issue:

Solution 1: Whenever nn.MHA and F.MHA are executed with is_causal set, we internally create a causal mask at the significant expense of allocating a tensor and filling it with a triangular causal matrix. This increases memory usage and runtime. To add insult to injury, in all current (and likely future) execution scenarios, MHA is called by a model using the nn.Transformer API, which already has that matrix and passes it from nn.Module to nn.Module. The passing in of attn_mask then has to be suppressed by nn.TransformerEncoderLayer, only for nn.MHA to immediately allocate the very same tensor again to satisfy the requirement to have an attention mask for the computation. (We expect new use cases to use SDPA directly.)

Solution 2: We align the behavior of nn.MHA and F.MHA with the rest of the existing nn.Transformer API, and require the attention mask to be passed into nn.MHA in addition to is_causal, which serves as an optional indicator about the nature of the attention mask rather than as an alternative to attn_mask. Then, when we choose the code path for processing MHA with need_weights or a key_padding_mask, we have the attn_mask passed down through the nn.Transformer* hierarchy, without the added overhead of allocating an attention mask as in scenario 1.

This PR implements solution 2, which offers better performance and, in retrospect, aligns MHA better with the rest of the Transformer modules as the definition of SDPA evolved into a more streamlined high-performance operator. It ostensibly changes how is_causal works, by requiring the attention mask to be specified. However, as described here, and as shown in the submitted issue, is_causal is not working as intended today, so it requires a change regardless.

In that sense, a change in API does not occur per se, as the current implementation is not working, and a change has to occur either way to resolve the submitted issue, breaking any use cases that depend on the current implementation. Checks exist (and more can be added) that flag any scenario where is_causal is passed as True but no attention mask is provided, ensuring that there is no quiet change from even the faulty behavior present in 2.0.

As an upside, the present implementation will improve performance by addressing the passing of the is_causal flag from Transformer modules to MHA, speeding up training for these examples, e.g., fine-tuning BERT, RoBERTa, and XLM-R models.
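
Under solution 2, a caller supplies both the mask and the hint (a usage sketch with illustrative shapes):

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
x = torch.randn(2, 10, 32)
causal_mask = torch.triu(torch.ones(10, 10, dtype=torch.bool), diagonal=1)
# is_causal=True is now a hint about attn_mask, not a substitute for it:
out, _ = mha(x, x, x, attn_mask=causal_mask, is_causal=True)
```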

Differential Revision: D44245725

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97214
Approved by: https://github.com/albanD
2023-03-25 01:36:30 +00:00
CedricPicron
cf0ba1b9c0 Use L1 loss for Smooth L1 loss with beta=0 (#97022)
Fixes #96813.

Comments:

1. I wasn't able to test, since tools/nightly.py does not allow for a GPU build (and I don't want to build from scratch).
2. In theory, the bug (i.e., NaNs) can still occur when beta is very small (e.g., `beta=1e-50`), but I'm not sure whether anybody cares.
3. Some checks within the smooth_l1_loss C++ code could be changed to check for `beta > 0` instead of `beta >= 0`.
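
The equivalence being relied on is observable from Python (a sketch):

```python
import torch
import torch.nn.functional as F

inp = torch.randn(8, requires_grad=True)
tgt = torch.randn(8)
# With beta=0 the quadratic branch vanishes, so the loss (and its gradient,
# which previously could produce NaNs from a 0/0) matches plain L1 loss:
assert torch.allclose(F.smooth_l1_loss(inp, tgt, beta=0.0), F.l1_loss(inp, tgt))
```
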
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97022
Approved by: https://github.com/jbschlosser
2023-03-24 19:10:32 +00:00
Michael Gschwind
61cb544397 Align mask formatting of both masks more closely (#96286)
Summary: Align mask formatting of both masks more closely

Test Plan: sandcastle & github

Differential Revision: D43878634

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96286
Approved by: https://github.com/cpuhrsch
2023-03-11 02:18:05 +00:00
Driss Guessous
11aab72dc9 [SDPA] Add an optional scale kwarg (#95259)
# Summary
This PR adds an optional kwarg to torch.nn.functional.scaled_dot_product_attention().
The new kwarg is a scaling factor that is applied after the q@k.T step of the computation. The efficient kernel was updated to support it, and the flash and math kernels were minimally updated to support it as well.

Will reduce the complexity of: #94729 and has been asked for by a couple of users.

# Review Highlights
- As far as I know, I did this the correct way, and it is both BC and FC compliant. However, I always seem to break internal workloads, so I would love for someone to advise whether I did this right.
- I named the optional arg 'scale'. This is probably dumb and I should name it 'scale_factor'. I will make that change if reviewers think we should rename it, annoying as it is.
- 'scale' is interpreted as `Q@K.T * (scale)`
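
Concretely, the default corresponds to scale = 1/sqrt(head_dim) of the query, so passing that factor explicitly should be a no-op (a sketch assuming the kwarg as described):

```python
import math

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
out_default = F.scaled_dot_product_attention(q, k, v)
out_scaled = F.scaled_dot_product_attention(q, k, v, scale=1 / math.sqrt(16))
assert torch.allclose(out_default, out_scaled, atol=1e-6)
```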

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95259
Approved by: https://github.com/cpuhrsch
2023-03-08 18:07:40 +00:00
Michael Gschwind
03b6e6979c Transformers: fix src and key padding mask bool regression (#96009)
Summary: fix src and pad mask bool regression

This fixes a regression introduced previously with #92733. That PR unified the testing of masks to remove byte tensors as a permissible mask type, introduced a mask compatibility check, and converted masks to FP masks. The problem addressed in this PR was that after the first mask had been converted, the check for mask compatibility would fail.

Test Plan: sandcastle & github

Differential Revision: D43782858

Fixes  https://github.com/pytorch/pytorch/issues/95702

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96009
Approved by: https://github.com/malfet
2023-03-05 01:50:46 +00:00
soulitzer
e5c2a35d83 Add check that embedding_bag's weight is 2D (#94931)
Fixes https://github.com/pytorch/pytorch/issues/94445

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94931
Approved by: https://github.com/albanD
2023-02-16 02:37:47 +00:00
Driss Guessous
70026aaad6 [SDPA] update type hint for scaled_dot_product_attention and documentation (#94008)
# Summary
- Adds type hinting support for SDPA
- Updates the documentation adding warnings and notes on the context manager
- Adds scaled_dot_product_attention to the non-linear activation function section of nn.functional docs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94008
Approved by: https://github.com/cpuhrsch
2023-02-10 18:02:43 +00:00
Natalia Gimelshein
a5daea69fb teach inductor to handle floor (#94341)
Per title; this happens when there's upsampling with a non-integer scale.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94341
Approved by: https://github.com/ezyang
2023-02-10 11:21:57 +00:00
PyTorch MergeBot
6007874bbb Revert "teach inductor to handle floor (#94341)"
This reverts commit e7df9aaec8.

Reverted https://github.com/pytorch/pytorch/pull/94341 on behalf of https://github.com/huydhn due to Sorry for reverting your PR, but the CudaTest failure looks related. It fails on both the PR and trunk e7df9aaec8
2023-02-09 19:31:08 +00:00
Natalia Gimelshein
e7df9aaec8 teach inductor to handle floor (#94341)
Per title; this happens when there's upsampling with a non-integer scale.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94341
Approved by: https://github.com/ezyang
2023-02-09 17:09:35 +00:00
milesial
6c555b29a8 MHA optimizations (#93234)
Slight perf optimizations for regular MHA by reducing the number of kernels called

Before:
![image](https://user-images.githubusercontent.com/30204471/215349212-172c6364-9e3c-4fd1-92b6-8ddd9931613e.png)

After:
![image](https://user-images.githubusercontent.com/30204471/215349247-021dd9e6-f6ca-40a2-8de8-0805af001f69.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93234
Approved by: https://github.com/drisspg
2023-02-03 15:18:35 +00:00
Driss Guessous
3df0e26e20 [SDPA] Remove private version and only utilize public version (#94004)
# Summary
Due to internal failures, we needed to keep the private call in torch.nn.mha. This PR undoes that change, so that we call the public function, and removes the private function.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94004
Approved by: https://github.com/cpuhrsch, https://github.com/albanD
2023-02-03 08:12:09 +00:00
103yiran
d9117b93fb unsqueeze only when dim = 3 (#91052)
unsqueeze is not necessary if view is used

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91052
Approved by: https://github.com/albanD
2023-01-31 16:28:23 +00:00
Driss Guessous
ca8f5e177a Use the old aten underscored function for Predictor (#93096)
Summary:
Errors reported via https://fb.prod.workplace.com/groups/1405155842844877/permalink/6644919482201794/

The problem is that the scriptable op set between predictor and the latest build of master is different.

Test Plan: Sandcastle testing

Differential Revision: D42786069

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93096
Approved by: https://github.com/mikekgfb
2023-01-28 03:14:18 +00:00
Michael Gschwind
7265f60ad0 Regularize mask handling for attn_mask and key_padding_mask (#92733)
Summary:
Regularize mask handling for attn_mask and key_padding_mask
* Update documentation to remove reference to byte masks (which were deprecated long ago)
* Introduce a check and a deprecation warning if the attn_mask and key_padding_mask types mismatch
* Convert all masks to float before combining
* Combine by adding (a minimal sketch follows)
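
A minimal sketch of that combining rule (illustrative shapes, not the internal code):

```python
import torch

L, N = 10, 2  # target length, batch size
attn_mask = torch.zeros(L, L)
attn_mask[0, 1] = float("-inf")              # a blocked position
key_padding_mask = torch.zeros(N, 1, L)
key_padding_mask[0, 0, -1] = float("-inf")   # a padded key in sample 0
# With both masks converted to float, combining is broadcasted addition,
# applied to the attention logits before softmax:
merged = attn_mask + key_padding_mask        # shape (N, L, L)
```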

Test Plan: sandcastle & github CI

Differential Revision: D42653215

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92733
Approved by: https://github.com/ngimel, https://github.com/drisspg
2023-01-24 14:12:05 +00:00
Driss Guessous
df14650f0b [SDPA] Update SDPA API and make function Public (#92189)
# Summary
In preparation for the PT 2.0 launch, this PR updates SDPA's API and makes the function a public nn.functional function.

## Changes
### API
Previously, the function signature was:
`scaled_dot_product_attention(query, key, value, attn_mask=None, need_attn_weights=False, dropout_p=0.0, is_causal=False) -> (Tensor, Tensor)`
Updated signature:
`scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) -> Tensor`

This PR removes the need_attn_weights optional boolean variable and updates the return type to a singular tensor.

#### Reasoning:
The main goal of this function is to provide an easy interface for users to call into fused attention kernels, e.g., FlashAttention. The fused kernels do not currently support arbitrary attn_mask or dropout, but there is a PR to mem-efficient attention to enable these. We want to have the API surface ready for when the backing kernels get updated.

The fused kernels save on memory usage by not materializing the weights, and it is unlikely that a fast fused implementation will ever enable this feature, so we are removing it.

Discussed with folks at FAIR/xFormers, who +1'd this API change.

#### Make function Public
In preparation for the pt 2.0 launch we make the function public to start to generate user feedback

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92189
Approved by: https://github.com/cpuhrsch
2023-01-23 20:50:46 +00:00
Michael Gschwind
af589b3d1f switch causal mask for is_causal flag (#91171)
Summary: switch causal mask for is_causal flag

Test Plan: sandcastle & github

Differential Revision: D42089340

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91171
Approved by: https://github.com/wushirong, https://github.com/drisspg
2022-12-30 17:24:58 +00:00
joncrall
ad782ff7df Enable xdoctest runner in CI for real this time (#83816)
Builds on #83317 and enables running the doctests. Just need to figure out what is causing the failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83816
Approved by: https://github.com/ezyang, https://github.com/malfet
2022-12-29 05:32:42 +00:00
Joel Schlosser
3d8834bdbf SymIntify F.interpolate() with recompute_scale_factor=True (#91318)
This PR makes the minor changes necessary to get `F.interpolate()` working with symbolic shapes when `recompute_scale_factor=True` + adds `OpInfo` samples to test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91318
Approved by: https://github.com/ezyang
2022-12-29 01:42:56 +00:00
Michael Gschwind
512ec181ec Introduce causal mask (#90508)
Summary: Introduce causal mask

This PR introduces a causal mask option _causal_mask (as well as causal mask detection if attn_mask is provided), since current custom kernels do not support arbitrary masks.

Test Plan: sandcastle & github ci/cd

Differential Revision: D41723137

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90508
Approved by: https://github.com/albanD
2022-12-16 21:39:42 +00:00
Driss Guessous
78bdb858f9 Call _sdp_attention in nn.functional.mha (#89470)
# Summary
Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89470
Approved by: https://github.com/cpuhrsch, https://github.com/mikekgfb
2022-12-02 19:46:22 +00:00
PyTorch MergeBot
f1415b8cb6 Revert "Call _sdp_attention in nn.functional.mha (#89470)"
This reverts commit 4d7ec30220.

Reverted https://github.com/pytorch/pytorch/pull/89470 on behalf of https://github.com/jeanschmidt due to breaking internal builds
2022-11-30 16:16:24 +00:00
PyTorch MergeBot
618a585f6c Revert "replace double transpose with single permute in nn.f.mha (#89847)"
This reverts commit b9afa92827.

Reverted https://github.com/pytorch/pytorch/pull/89847 on behalf of https://github.com/jeanschmidt due to Need to revert this commit as it is causing conflict when reverting #89470
2022-11-30 16:03:48 +00:00
Driss Guessous
b9afa92827 replace double transpose with single permute in nn.f.mha (#89847)
# Summary

I forgot about permute, which was exactly what I wanted. Quick perf bump.
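
The equivalence in question, in generic form (not the exact MHA tensors):

```python
import torch

x = torch.randn(2, 3, 4, 5)
# One permute expresses the same reshuffle as two chained transposes:
assert torch.equal(x.transpose(0, 1).transpose(1, 2), x.permute(1, 2, 0, 3))
```
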
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89847
Approved by: https://github.com/cpuhrsch, https://github.com/albanD
2022-11-29 22:18:42 +00:00
Driss Guessous
4d7ec30220 Call _sdp_attention in nn.functional.mha (#89470)
# Summary
Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89470
Approved by: https://github.com/cpuhrsch, https://github.com/mikekgfb
2022-11-29 03:02:10 +00:00
foram-chandra
e19a7165fd [nn] Remove deprecation warning from nn.functional.{tanh, sigmoid} (#86905)
Fixes #65909

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86905
Approved by: https://github.com/albanD, https://github.com/kit1980
2022-11-24 00:34:26 +00:00
Nikita Karetnikov
0a1a53083e [primTorch] Enable regex error testing for some refs (#87765)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87765
Approved by: https://github.com/mruberry
2022-11-23 23:36:27 +00:00
David Boetius
b652fbc57a Fix torch.nn.functional.gelu docstring formatting (#89061)
The docstring of `torch.nn.functional.gelu` is formatted incorrectly, so part of the math isn't rendered and there are extra blocks where there shouldn't be any: https://pytorch.org/docs/stable/generated/torch.nn.functional.gelu.html

I didn't build the docs, so I am not 100% sure that I got the formatting right, but I am confident.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89061
Approved by: https://github.com/bdhirsh, https://github.com/kit1980
2022-11-18 01:57:41 +00:00
Ryan Spring
534ae6ae47 [primTorch] Implement group norm reference (#87054)
Add group norm reference
Split from #81191
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87054
Approved by: https://github.com/mruberry
2022-11-11 01:08:20 +00:00
Kazuaki Ishizaki
2ddefbdc3c Fix typos used in documents under torch directory (#88300)
This PR fixes typos in comments of Python files that were found via the search box at https://pytorch.org/docs/master/search.html

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88300
Approved by: https://github.com/lezcano
2022-11-02 09:38:13 +00:00
Rui Zhu
4b757f4633 Assert if padding mask type is unexpected (#86353) (#87106)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86353

Fix the issue described in
https://github.com/pytorch/pytorch/issues/86120

Test Plan: buck test mode/opt caffe2/test:test_transformers -- test_train_with_long_type_pad

Differential Revision: D40129968

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87106
Approved by: https://github.com/malfet
2022-10-20 16:01:54 +00:00
Andrew M. James
db65909255 [Docs] Update mm family ops and F.linear to note limited sparse support. (#86220)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86220
Approved by: https://github.com/cpuhrsch
2022-10-18 19:55:18 +00:00
Nikita Karetnikov
d56017a14f [primTorch] Add ref for triplet_margin_loss, improve triplet_margin_with_distance_loss (#85614)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85614
Approved by: https://github.com/lezcano, https://github.com/mruberry
2022-10-12 18:37:58 +00:00
lezcano
787028cadb Implement col2im decomposition and fix im2col and add a few preconditions (#85541)
As per title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85541
Approved by: https://github.com/jansel
2022-09-30 09:31:53 +00:00
Srikumar Sastry
c8776dca6a Remove extra with in value error exception statement (#84713)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84713
Approved by: https://github.com/ngimel
2022-09-27 18:43:39 +00:00
Driss Guessous
253ffbf28b Exposing native _scaled_dot_product_attention to torch.nn (#85044)
# Summary
This exposes the _scaled_dot_product_attention function to python in the nn namespace. It is still underscored because the API for args and kwargs is still in flux for the next few weeks; it will eventually land as a prototype feature.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85044
Approved by: https://github.com/cpuhrsch
2022-09-22 16:30:16 +00:00
PyTorch MergeBot
a3dc338ee1 Revert "Exposing native _scaled_dot_product_attention to torch.nn (#85044)"
This reverts commit 9fdd8a8b7f.

Reverted https://github.com/pytorch/pytorch/pull/85044 on behalf of https://github.com/huydhn due to This breaks CUDA 10.2 in trunk. We are deprecating CUDA 10.2, but it is still here in the meantime
2022-09-21 08:34:51 +00:00
Driss Guessous
9fdd8a8b7f Exposing native _scaled_dot_product_attention to torch.nn (#85044)
# Summary
This exposes the _scaled_dot_product_attention function to python in the nn namespace. It is still underscored because the API for args and kwargs is still in flux for the next few weeks; it will eventually land as a prototype feature.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85044
Approved by: https://github.com/cpuhrsch
2022-09-21 03:09:08 +00:00
joncrall
b136f3f310 More doctest refinements. (#83317)
Follow up to #82797

Now that the doctests themselves are in a better state, we should be able to enable xdoctest on the CI so they stay that way.

@ezyang @vadimkantorov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83317
Approved by: https://github.com/ezyang
2022-08-22 20:07:26 +00:00
Edward Z. Yang
cb64b558ee Add spaces so example is flake8 compatible (#83420)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83420
Approved by: https://github.com/jbschlosser
2022-08-15 21:39:57 +00:00
joncrall
4618371da5 Integrate xdoctest - Rebased (#82797)
This is a new version of #15648 based on the latest master branch.

Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)
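
For reference, the directive looks like this inside a docstring (an illustrative snippet, not one of the actual skipped tests):

```
import torch

def double(x):
    """Return ``2 * x``.

    Example:
        >>> # xdoctest: +SKIP
        >>> double(torch.ones(2))
        tensor([2., 2.])
    """
    return 2 * x
```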

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
2022-08-12 02:08:01 +00:00
Alex Li
1fedd40424 Update cross entropy documentation to mention logits clearly (#82538)
### Description
Improved the documentation for cross entropy as it is a common point of confusion.
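
The point being documented, as a sketch: `cross_entropy` expects raw logits and applies `log_softmax` internally, so it matches `nll_loss` over `log_softmax` outputs.

```
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)            # unnormalized scores, not probabilities
target = torch.randint(0, 10, (4,))
loss = F.cross_entropy(logits, target)
torch.testing.assert_close(
    loss, F.nll_loss(F.log_softmax(logits, dim=1), target)
)
```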

### Issue
#82081

### Testing
I did not test this change as it is tiny and documentation-only.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82538
Approved by: https://github.com/jbschlosser
2022-08-08 22:24:28 +00:00
ProGamerGov
357b7d589c Fix docstring inconsistencies: string -> str, boolean -> bool (#82410)
### Description

Throughout the PyTorch docs and codebase, the `string` type in docstrings is referred to by two separate names. This leads to inconsistent docs, as you can see here: https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html#torch.nn.Conv3d

This PR fixes this issue by ensuring that all mentions of the string type in docstrings use the same format that Sphinx generates hyperlinks for.

### Testing
No testing should be required for this change

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82410
Approved by: https://github.com/jbschlosser
2022-07-28 21:29:57 +00:00
kylematoba
66cf1b6459 correct argument name in docs (#81485)
The recently introduced `average_attn_weights` argument is documented incorrectly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81485
Approved by: https://github.com/albanD
2022-07-20 20:07:16 +00:00
soulitzer
bd75b2fea1 Add ref for nn.functional.prelu (#79768)
TODO:
- not sure if these error-inputs work for all devices (awaiting CI)
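
A hypothetical sketch of the semantics such a ref captures (not the PR's code): `prelu(x) = x` where `x >= 0`, else `weight * x`, with a multi-element weight broadcast over the channel dim.

```
import torch

def prelu_ref(x, weight):
    if weight.numel() > 1:
        # Broadcast per-channel weight over (N, C, *spatial).
        weight = weight.view(1, -1, *([1] * (x.dim() - 2)))
    return torch.where(x >= 0, x, weight * x)

x = torch.randn(2, 3, 4)
w = torch.rand(3)
torch.testing.assert_close(prelu_ref(x, w), torch.nn.functional.prelu(x, w))
```
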
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79768
Approved by: https://github.com/mruberry
2022-07-07 17:04:47 +00:00
Albert Chung
b4ed13ea0f Update docstring for scale_factor in torch.nn.functional.interpolate. (#80807)
Fixes #80786

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80807
Approved by: https://github.com/ezyang
2022-07-04 04:36:16 +00:00
Joel Benjamin Schlosser
5953fd9133 Revert behavior of Dropout2d on 3D inputs to 1D channel-wise dropout behavior & warn
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79549

Approved by: https://github.com/ngimel, https://github.com/albanD
2022-06-15 14:56:43 +00:00
Joel Benjamin Schlosser
2d73c8e6e0 Add Dropout1d module
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79545

Approved by: https://github.com/ngimel, https://github.com/albanD
2022-06-15 14:39:07 +00:00
PyTorch MergeBot
3556457dd2 Revert "kl_div: fix for grads wrt target, double backward, forward-over-reverse AD support. (#79007)"
This reverts commit 72ad222cff.

Reverted https://github.com/pytorch/pytorch/pull/79007 on behalf of https://github.com/janeyx99 due to Broke test_fn_fwgrad_bwgrad_nn_functional_kl_div_cpu_float64 on trunk https://hud.pytorch.org/minihud?name_filter=pull%20/%20linux-xenial-py3.7-clang7-asan%20/%20test%20(default,%202,%205,%20linux.2xlarge)
2022-06-09 13:07:03 +00:00
Nikita Vedeneev
72ad222cff kl_div: fix for grads wrt target, double backward, forward-over-reverse AD support. (#79007)
Fixes https://github.com/pytorch/pytorch/issues/78867,
fixes https://github.com/pytorch/pytorch/issues/65466.
Adds forward-over-reverse AD support.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79007
Approved by: https://github.com/soulitzer, https://github.com/jbschlosser
2022-06-09 09:06:52 +00:00
Rohit Goswami
5a95b20d0f DOC: Harmonize ELU documentation with the module doc (#78909)
Fixes #77055 by simply referring to the module docs as noted in the issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78909
Approved by: https://github.com/albanD
2022-06-06 14:14:11 +00:00
samdow
b7cb4eae6b Fix embedding jvp support by making embedding_renorm ignore forward mode AD (#78560)
On functorch, we started seeing [embedding forward mode fail](https://github.com/pytorch/functorch/pull/816). From looking at it, we figured out that embedding recently [got forward mode support enabled](369d9f4137), and that forward mode with embedding and [max_norm doesn't work with gradcheck](https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_methods_invocations.py#L8877-L8881), so it's not checked.

What was happening is that `embedding_renorm` was setting `torch.no_grad()`, which only turns off backward-mode AD, so functorch's jvp tests were still using forward-mode AD during the `embedding_renorm` call. This change makes it so that we don't use forward mode during the `embedding_renorm` call.
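
A minimal sketch of the underlying gotcha (ours, not the PR's test): tangents keep propagating under `torch.no_grad()`, which suspends only reverse-mode AD.

```
import torch
import torch.autograd.forward_ad as fwAD

with fwAD.dual_level():
    x = fwAD.make_dual(torch.randn(3), torch.ones(3))
    with torch.no_grad():       # does not stop forward-mode AD
        y = x * 2
    print(fwAD.unpack_dual(y).tangent)  # tensor([2., 2., 2.]), not None
```
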
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78560
Approved by: https://github.com/soulitzer, https://github.com/albanD
2022-06-03 19:14:51 +00:00
PyTorch MergeBot
d578197747 Revert "Fix embedding jvp support by making embedding_renorm ignore forward mode AD (#78560)"
This reverts commit ce7c7bb2a9.

Reverted https://github.com/pytorch/pytorch/pull/78560 on behalf of https://github.com/malfet due to broke XLA (on CI and trunk), see ce7c7bb2a9
2022-06-02 17:40:34 +00:00
samdow
ce7c7bb2a9 Fix embedding jvp support by making embedding_renorm ignore forward mode AD (#78560)
On functorch, we started seeing [embedding forward mode fail](https://github.com/pytorch/functorch/pull/816). From looking at it, we figured out that embedding recently [got forward mode support enabled](369d9f4137), and that forward mode with embedding and [max_norm doesn't work with gradcheck](https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_methods_invocations.py#L8877-L8881), so it's not checked.

What was happening is that `embedding_renorm` was setting `torch.no_grad()`, which only turns off backward-mode AD, so functorch's jvp tests were still using forward-mode AD during the `embedding_renorm` call. This change makes it so that we don't use forward mode during the `embedding_renorm` call.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78560
Approved by: https://github.com/soulitzer, https://github.com/albanD
2022-06-02 13:40:21 +00:00
Kshiteej K
4e1f41f66a [docs][nn] conv: complex support note (#78351)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78351
Approved by: https://github.com/anjali411, https://github.com/jbschlosser
2022-05-26 20:33:36 +00:00
Natalia Gimelshein
362525724b type promote clamp (#77035)
Fixes #76630
When clamp(Tensor, Tensor) is structured, big parts of this PR won't be needed, but for now let's fix type promotion to make behavior more regular.
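
A sketch of the behavior this regularizes (assuming the standard type-promotion rules apply, as for other binary ops):

```
import torch

x = torch.arange(5)              # int64
bound = torch.tensor(0.5)        # float32
# An integral input clamped against a floating bound now promotes,
# yielding a floating result instead of an inconsistent one.
print(torch.clamp(x, min=bound).dtype)  # torch.float32
```
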
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77035
Approved by: https://github.com/mruberry
2022-05-09 05:54:17 +00:00
vitrioil
f92cddd890 Removed direct doc formatting
Fixes #76034

This does not make Python remove all `__doc__`s, because in some places `__doc__` is assigned a string directly.

Example:
04b3313379/torch/nn/modules/conv.py (L174-L233)

Since there are quite a few of these, I will add all of them together in this PR later. (Basically, a lot of docstrings will still persist even with `-OO` enabled.)
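
For context, a sketch of why this matters (illustrative, not from the PR): under `python -OO`, literal docstrings are stripped and `__doc__` is `None`, so formatting it in place raises, while an explicit `__doc__` assignment still works.

```
def f(x):
    """Template docstring."""
    return x

# Under -OO, f.__doc__ is None and .format() would raise; guard it.
if f.__doc__ is not None:
    f.__doc__ = f.__doc__.format()
```
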
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76619
Approved by: https://github.com/albanD
2022-05-02 14:14:33 +00:00
Yuge Zhang
3ac27e78ca Fix typehint of multi_head_attention_forward
Fixes #76169

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76170
Approved by: https://github.com/jbschlosser
2022-04-27 13:47:43 +00:00
Peter Bell
cb37e7a080 Remove F.pad python implementation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73433

Approved by: https://github.com/albanD, https://github.com/jbschlosser
2022-04-23 00:13:20 +00:00
vitrioil
29b004be7a Corrected documentation for supported padding
Fixes #72521

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76117
Approved by: https://github.com/jbschlosser
2022-04-20 17:36:01 +00:00
Mike Ruberry
b09769992f Improves the OpInfo out= tests
Edit: OpInfos separated into their own PRs to debug an ASAN failure that doesn't identify the failing test properly. This PR now just updates the out= tests.

Adds OpInfos for:

- nn.functional.smooth_l1_loss
- nn.functional.l1_loss
- nn.functional.pdist
- nn.functional.binary_cross_entropy
- nn.functional.triplet_margin_loss
- nn.functional.triplet_margin_with_distance_loss
- nn.functional.max_unpool{1, 2, 3}d
- nn.functional.alpha_dropout
- nn.functional.soft_margin_loss
- nn.functional.multilabel_soft_margin_loss
- nn.functional.multilabel_margin_loss
- nn.functional.multi_margin_loss
- nn.functional.margin_ranking_loss

These OpInfos were taken from https://github.com/pytorch/pytorch/pull/67560, https://github.com/pytorch/pytorch/pull/67823, https://github.com/pytorch/pytorch/pull/68625, and https://github.com/pytorch/pytorch/pull/67079. The sample input update from https://github.com/pytorch/pytorch/pull/67017 is also rolled into this PR.

cc @zou3519 @nikitaved @pmeier @vfdev-5 @dagitses
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75782
Approved by: https://github.com/ngimel
2022-04-15 06:16:01 +00:00
Edward Z. Yang
0a1bc5f501 Miscellaneous __torch_function__ fixes
I figured these out by unconditionally turning on a no-op torch function
mode on the test suite and then fixing errors as they showed up (a minimal
sketch of such a mode follows the list below).  Here's what I found:

- _parse_to failed internal assert when __torch_function__'ed because it
  claims its name is "to" to the argument parser; added a name override
  so we know how to find the correct name

- Infix operator magic methods on Tensor did not uniformly handle
  __torch_function__ and the TypeError-to-NotImplemented conversion.  Now, we always
  do the __torch_function__ handling in
  _wrap_type_error_to_not_implemented and your implementation of
  __torch_function__ gets its TypeErrors converted to NotImplemented
  (for better or for worse; see
  https://github.com/pytorch/pytorch/issues/75462 )

- A few cases where code was testing whether a Tensor was
  Tensor-like in the wrong way now use is_tensor_like (in grad
  and in distributions).  Also update docs for has_torch_function to
  push people to use is_tensor_like.

- is_grads_batched was dropped from grad in handle_torch_function, now
  fixed

- Report that you have a torch function even if torch function is
  disabled if a mode is enabled.  This makes it possible for a mode
  to return NotImplemented, pass to a subclass which does some
  processing and then pass back to the mode even after the subclass
  disables __torch_function__ (so the tensors are treated "as if"
  they are regular Tensors).  This brings the C++ handling behavior
  in line with the Python behavior.

- Make the Python implementation of overloaded types computation match
  the C++ version: when torch function is disabled, there are no
  overloaded types (because they all report they are not overloaded).
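
A minimal sketch of such a no-op mode, written against the `TorchFunctionMode` API as it later stabilized in `torch.overrides`:

```
import torch
from torch.overrides import TorchFunctionMode

class NoOpMode(TorchFunctionMode):
    # Passes every torch.* call straight through; running a test suite
    # under it flushes out operators that mishandle __torch_function__.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

with NoOpMode():
    torch.randn(2) + 1  # every call here goes through the mode
```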

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75484

Approved by: https://github.com/zou3519
2022-04-11 16:52:16 +00:00
Scott Wolchok
87f40ee6d6 [PyTorch] Existing MHA: fuse the attn_mask addition (#73219)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73219

Saw a report that this elementwise add is causing overhead. IIUC this is easy to fuse?
ghstack-source-id: 152549975

Test Plan:
CI, review

Ran benchmark_transformers.par mha --batch-size 64 --max-sequence-length 128 --avg-sequence-length 256 --large --use-real-data-distribution --use-mask
and looked at the PT time number

```
before:
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True             PT Time: 1.24ms, NativePT Time: 1000000000.00ms, HF Time: 1.10ms,             PT FLOPS: 59.07TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.46TFLOP/s
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True             PT Time: 1.23ms, NativePT Time: 1000000000.00ms, HF Time: 1.09ms,             PT FLOPS: 59.57TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.75TFLOP/s
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True             PT Time: 1.24ms, NativePT Time: 1000000000.00ms, HF Time: 1.09ms,             PT FLOPS: 58.87TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.77TFLOP/s

after:
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True             PT Time: 1.22ms, NativePT Time: 1000000000.00ms, HF Time: 1.10ms,             PT FLOPS: 60.07TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.51TFLOP/s
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True             PT Time: 1.22ms, NativePT Time: 1000000000.00ms, HF Time: 1.09ms,             PT FLOPS: 59.80TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.69TFLOP/s
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True             PT Time: 1.21ms, NativePT Time: 1000000000.00ms, HF Time: 1.09ms,             PT FLOPS: 60.21TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.86TFLOP/s
```

Inspected a Kineto trace and confirmed that an elementwise add was fused into baddbmm.
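
A sketch of the fusion (ours, not the PR's diff): `baddbmm` folds the mask addition into the matmul instead of launching a separate elementwise kernel.

```
import torch

q = torch.randn(16, 128, 64)
k = torch.randn(16, 128, 64)
mask = torch.randn(16, 128, 128)

unfused = torch.bmm(q, k.transpose(1, 2)) + mask   # matmul + separate add
fused = torch.baddbmm(mask, q, k.transpose(1, 2))  # one fused call
torch.testing.assert_close(unfused, fused)
```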

Additional opportunity: I see a copy_ inside baddbmm that wasn't happening with the bmm path and I'm not sure why. Perhaps something went wrong with the structured kernels port by ezyang?

Reviewed By: ezyang

Differential Revision: D34160547

fbshipit-source-id: 78d406fb035e6f3bf13af2c9443a886eada35ac4
(cherry picked from commit aaffc39b24058742cb9ae42105f95b3eafe9d7f5)
2022-04-04 20:31:22 +00:00
Peter Bell
7f051b4d2b Implement F.pad in ATen
This moves the C++ torch pad function into ATen proper. Once the
forward-compatibility period is over, the python interface can use
this directly.
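
The python-level semantics being ported, for reference (a sketch): padding pairs are given innermost dimension first.

```
import torch
import torch.nn.functional as F

x = torch.ones(1, 1, 2, 2)
y = F.pad(x, (1, 1), mode="constant", value=0.0)  # pad last dim by 1 each side
print(y.shape)  # torch.Size([1, 1, 2, 4])
```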

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73431

Approved by: https://github.com/ezyang
2022-04-01 01:10:12 +00:00
Davit Kobaladze
8e12d2bf25 fixes torch.jit.script lp_pool bug. (#73287)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/60258

I used the solution proposed in https://github.com/pytorch/pytorch/issues/61275.  His solution failed unit tests and there was no progress after 08/07/2021. I'm willing to fix problems if they arise during CI.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73287

Reviewed By: navahgar, zou3519

Differential Revision: D35057812

Pulled By: eellison

fbshipit-source-id: 8e82e9f73b9536979aecf476c5c65336cdffc93a
(cherry picked from commit e85e912a4edec1111623c5cbbba4171fe3bc5b1d)
2022-03-28 23:16:07 +00:00
Peter Bell
f86bb2d6e4 Implement _pad_circular in ATen
Closes #44459

This migrates the python implementation of `_pad_circular` to ATen and
removes the old C++ implementation that had diverged from python.

Note that `pad` can't actually use this until the
forward-compatibility period is over.
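
For reference, the wrap-around semantics being migrated (a sketch): the left pad is taken from the right edge and vice versa.

```
import torch
import torch.nn.functional as F

x = torch.arange(4.0).reshape(1, 1, 4)
print(F.pad(x, (1, 1), mode="circular"))
# tensor([[[3., 0., 1., 2., 3., 0.]]])
```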

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73410

Approved by: https://github.com/ezyang
2022-03-25 02:09:01 +00:00
Kushashwa Ravi Shrimali
452c26bbeb Fix functional.max_poolNd warning spam in the CI
Fixes https://github.com/pytorch/pytorch/issues/71257.

Warnings have been removed, please see [this](https://github.com/pytorch/pytorch/pull/71258#issuecomment-1058503649) comment.

cc: @Lezcano @jbschlosser @zou3519
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71258
Approved by: https://github.com/Lezcano, https://github.com/jbschlosser
2022-03-04 18:42:23 +00:00
Scott Wolchok
28339ddc25 [PyTorch] Hit fused addmm path in linear() for existing MHA (#72871)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72871

We do this same trick in the native MHA implementation; backport it for purposes of fair comparison.
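
The trick, sketched (ours, not the PR's diff): route `linear()` through a single `addmm`, fusing the bias add into the matmul.

```
import torch

x = torch.randn(32, 64)
w = torch.randn(128, 64)   # (out_features, in_features)
b = torch.randn(128)

fused = torch.addmm(b, x, w.t())        # bias + matmul in one kernel
torch.testing.assert_close(fused, x @ w.t() + b)
```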
ghstack-source-id: 149526858

Test Plan: CI

Reviewed By: ngimel

Differential Revision: D34176090

fbshipit-source-id: 8b578c29c4dcf0d85bae74dfbbb82db9a8f32dc7
(cherry picked from commit fd50170935)
2022-02-22 19:33:46 +00:00
Joel Schlosser
f670179c0a Fix doc regressions for various modules and functional forms (#73014)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73014

Fixes #72501
Fixes #72502
Fixes #72503
Fixes #72504
Fixes #72505
Fixes #72506
Fixes #72507
Fixes #72509
Fixes #72510

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D34305640

Pulled By: jbschlosser

fbshipit-source-id: 62f341633fdb0316eaa346cf7247865290eb830a
(cherry picked from commit 8362d264e7)
2022-02-17 22:40:18 +00:00
Vitaly Fedyunin
81fbeea760 Add docstrings to native_channel_shuffle (#72919)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72919

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D34274717

Pulled By: VitalyFedyunin

fbshipit-source-id: fa42f91ef2335e2594b19ef65d914c711f7a94fd
(cherry picked from commit a6f6fe9112)
2022-02-17 02:33:08 +00:00
Ryan Spring
4f8b986e28 Implement Tanh Gelu Approximation (#61439)
Summary:
1. Implements https://github.com/pytorch/pytorch/issues/39853
2. Adds an `approximate` string flag to Gelu
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enable Tanh Gelu in NvFuser

```
import math
import torch

def normcdf(x):  # standard normal CDF via erf
    return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu(x, approximate: str = 'none'):
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        return x * normcdf(x)
```

Linking XLA PR - https://github.com/pytorch/xla/pull/3039

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61439

Reviewed By: VitalyFedyunin

Differential Revision: D33894937

Pulled By: jbschlosser

fbshipit-source-id: b65e8fb6ea66168af8f34f45ed50e92737a33851
(cherry picked from commit 6e986f91a9)
2022-02-14 03:40:32 +00:00