Fixes #112597
### Output:
**BEFORE:**
```
functional.py:1 at module level:
D400: First line should end with a period (not 'e')
functional.py:438 in public function `fractional_max_pool2d_with_indices`:
D400: First line should end with a period (not ')')
functional.py:537 in public function `fractional_max_pool3d_with_indices`:
D400: First line should end with a period (not ')')
functional.py:646 in public function `max_pool1d_with_indices`:
D400: First line should end with a period (not ')')
functional.py:732 in public function `max_pool2d_with_indices`:
D400: First line should end with a period (not ')')
functional.py:818 in public function `max_pool3d_with_indices`:
D400: First line should end with a period (not ')')
functional.py:932 in public function `max_unpool1d`:
D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
functional.py:968 in public function `max_unpool2d`:
D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
functional.py:1000 in public function `max_unpool3d`:
D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
functional.py:1031 in public function `lp_pool2d`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:1031 in public function `lp_pool2d`:
D400: First line should end with a period (not 'f')
functional.py:1031 in public function `lp_pool2d`:
D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:1056 in public function `lp_pool1d`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:1056 in public function `lp_pool1d`:
D400: First line should end with a period (not 'f')
functional.py:1056 in public function `lp_pool1d`:
D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:1077 in public function `adaptive_max_pool1d_with_indices`:
D400: First line should end with a period (not ')')
functional.py:1119 in public function `adaptive_max_pool2d_with_indices`:
D400: First line should end with a period (not ')')
functional.py:1163 in public function `adaptive_max_pool3d_with_indices`:
D400: First line should end with a period (not ')')
functional.py:1220 in public function `adaptive_avg_pool2d`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:1220 in public function `adaptive_avg_pool2d`:
D400: First line should end with a period (not 'f')
functional.py:1220 in public function `adaptive_avg_pool2d`:
D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:1237 in public function `adaptive_avg_pool3d`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:1237 in public function `adaptive_avg_pool3d`:
D400: First line should end with a period (not 'f')
functional.py:1237 in public function `adaptive_avg_pool3d`:
D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:1255 in public function `dropout`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:1255 in public function `dropout`:
D400: First line should end with a period (not 't')
functional.py:1275 in public function `alpha_dropout`:
D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:1287 in public function `dropout1d`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:1287 in public function `dropout1d`:
D400: First line should end with a period (not ',')
functional.py:1325 in public function `dropout2d`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:1325 in public function `dropout2d`:
D400: First line should end with a period (not ',')
functional.py:1369 in public function `dropout3d`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:1369 in public function `dropout3d`:
D400: First line should end with a period (not ',')
functional.py:1408 in public function `feature_alpha_dropout`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:1408 in public function `feature_alpha_dropout`:
D400: First line should end with a period (not ',')
functional.py:1466 in public function `relu`:
D400: First line should end with a period (not 'r')
functional.py:1466 in public function `relu`:
D402: First line should not be the function's "signature"
functional.py:1491 in public function `glu`:
D400: First line should end with a period (not 'r')
functional.py:1491 in public function `glu`:
D402: First line should not be the function's "signature"
functional.py:1516 in public function `hardtanh`:
D400: First line should end with a period (not 'r')
functional.py:1516 in public function `hardtanh`:
D402: First line should not be the function's "signature"
functional.py:1542 in public function `relu6`:
D400: First line should end with a period (not 'r')
functional.py:1542 in public function `relu6`:
D402: First line should not be the function's "signature"
functional.py:1558 in public function `elu`:
D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:1582 in public function `selu`:
D400: First line should end with a period (not 'r')
functional.py:1582 in public function `selu`:
D402: First line should not be the function's "signature"
functional.py:1611 in public function `celu`:
D400: First line should end with a period (not 'r')
functional.py:1611 in public function `celu`:
D402: First line should not be the function's "signature"
functional.py:1638 in public function `leaky_relu`:
D400: First line should end with a period (not 'r')
functional.py:1638 in public function `leaky_relu`:
D402: First line should not be the function's "signature"
functional.py:1688 in public function `rrelu`:
D400: First line should end with a period (not 'r')
functional.py:1688 in public function `rrelu`:
D402: First line should not be the function's "signature"
functional.py:1755 in public function `tanhshrink`:
D400: First line should end with a period (not 'r')
functional.py:1755 in public function `tanhshrink`:
D402: First line should not be the function's "signature"
functional.py:1767 in public function `softsign`:
D400: First line should end with a period (not 'r')
functional.py:1767 in public function `softsign`:
D402: First line should not be the function's "signature"
functional.py:1806 in public function `softmin`:
D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:1832 in public function `softmax`:
D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:1868 in public function `gumbel_softmax`:
D401: First line should be in imperative mood (perhaps 'Sample', not 'Samples')
functional.py:1930 in public function `log_softmax`:
D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:1969 in public function `tanh`:
D400: First line should end with a period (not 'r')
functional.py:1969 in public function `tanh`:
D402: First line should not be the function's "signature"
functional.py:1980 in public function `sigmoid`:
D400: First line should end with a period (not 'r')
functional.py:1980 in public function `sigmoid`:
D402: First line should not be the function's "signature"
functional.py:1990 in public function `hardsigmoid`:
D400: First line should end with a period (not 'n')
functional.py:1990 in public function `hardsigmoid`:
D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:2057 in public function `silu`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:2057 in public function `silu`:
D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:2081 in public function `mish`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:2081 in public function `mish`:
D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:2100 in public function `hardswish`:
D400: First line should end with a period (not ':')
functional.py:2100 in public function `hardswish`:
D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:2136 in public function `embedding`:
D202: No blank lines allowed after function docstring (found 1)
functional.py:2136 in public function `embedding`:
D401: First line should be in imperative mood; try rephrasing (found 'A')
functional.py:2254 in public function `embedding_bag`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:2254 in public function `embedding_bag`:
D400: First line should end with a period (not 'e')
functional.py:2254 in public function `embedding_bag`:
D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes')
functional.py:2462 in public function `batch_norm`:
D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:2507 in public function `instance_norm`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:2507 in public function `instance_norm`:
D400: First line should end with a period (not 'a')
functional.py:2507 in public function `instance_norm`:
D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:2540 in public function `layer_norm`:
D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:2554 in public function `group_norm`:
D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:2567 in public function `local_response_norm`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:2567 in public function `local_response_norm`:
D400: First line should end with a period (not 'f')
functional.py:2567 in public function `local_response_norm`:
D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
functional.py:2611 in public function `ctc_loss`:
D401: First line should be in imperative mood; try rephrasing (found 'The')
functional.py:2679 in public function `nll_loss`:
D401: First line should be in imperative mood; try rephrasing (found 'The')
functional.py:2895 in public function `kl_div`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:2895 in public function `kl_div`:
D400: First line should end with a period (not 's')
functional.py:2895 in public function `kl_div`:
D401: First line should be in imperative mood; try rephrasing (found 'The')
functional.py:2978 in public function `cross_entropy`:
D401: First line should be in imperative mood; try rephrasing (found 'This')
functional.py:3069 in public function `binary_cross_entropy`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:3069 in public function `binary_cross_entropy`:
D400: First line should end with a period (not 't')
functional.py:3069 in public function `binary_cross_entropy`:
D401: First line should be in imperative mood; try rephrasing (found 'Function')
functional.py:3139 in public function `binary_cross_entropy_with_logits`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:3139 in public function `binary_cross_entropy_with_logits`:
D400: First line should end with a period (not 't')
functional.py:3139 in public function `binary_cross_entropy_with_logits`:
D401: First line should be in imperative mood; try rephrasing (found 'Function')
functional.py:3211 in public function `smooth_l1_loss`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:3211 in public function `smooth_l1_loss`:
D400: First line should end with a period (not 'e')
functional.py:3211 in public function `smooth_l1_loss`:
D401: First line should be in imperative mood; try rephrasing (found 'Function')
functional.py:3251 in public function `huber_loss`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:3251 in public function `huber_loss`:
D400: First line should end with a period (not 'e')
functional.py:3251 in public function `huber_loss`:
D401: First line should be in imperative mood; try rephrasing (found 'Function')
functional.py:3282 in public function `l1_loss`:
D400: First line should end with a period (not 'r')
functional.py:3282 in public function `l1_loss`:
D402: First line should not be the function's "signature"
functional.py:3313 in public function `mse_loss`:
D400: First line should end with a period (not 'r')
functional.py:3313 in public function `mse_loss`:
D402: First line should not be the function's "signature"
functional.py:3346 in public function `margin_ranking_loss`:
D400: First line should end with a period (not 'r')
functional.py:3346 in public function `margin_ranking_loss`:
D402: First line should not be the function's "signature"
functional.py:3382 in public function `hinge_embedding_loss`:
D400: First line should end with a period (not 'r')
functional.py:3382 in public function `hinge_embedding_loss`:
D402: First line should not be the function's "signature"
functional.py:3411 in public function `multilabel_margin_loss`:
D400: First line should end with a period (not 'r')
functional.py:3411 in public function `multilabel_margin_loss`:
D402: First line should not be the function's "signature"
functional.py:3439 in public function `soft_margin_loss`:
D400: First line should end with a period (not 'r')
functional.py:3439 in public function `soft_margin_loss`:
D402: First line should not be the function's "signature"
functional.py:3462 in public function `multilabel_soft_margin_loss`:
D400: First line should end with a period (not 'r')
functional.py:3462 in public function `multilabel_soft_margin_loss`:
D402: First line should not be the function's "signature"
functional.py:3510 in public function `cosine_embedding_loss`:
D400: First line should end with a period (not 'r')
functional.py:3510 in public function `cosine_embedding_loss`:
D402: First line should not be the function's "signature"
functional.py:3543 in public function `multi_margin_loss`:
D400: First line should end with a period (not 'r')
functional.py:3543 in public function `multi_margin_loss`:
D402: First line should not be the function's "signature"
functional.py:3708 in public function `upsample` (skipping F811,B950):
D103: Missing docstring in public function
functional.py:3713 in public function `upsample` (skipping F811,B950):
D103: Missing docstring in public function
functional.py:3718 in public function `upsample` (skipping F811):
D205: 1 blank line required between summary line and description (found 0)
functional.py:3718 in public function `upsample` (skipping F811):
D400: First line should end with a period (not 'n')
functional.py:3783 in private function `_is_integer`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:3794 in public function `interpolate` (skipping F811,B950):
D103: Missing docstring in public function
functional.py:3799 in public function `interpolate` (skipping F811,B950):
D103: Missing docstring in public function
functional.py:3804 in public function `interpolate` (skipping F811,B950):
D103: Missing docstring in public function
functional.py:3809 in public function `interpolate` (skipping F811):
D103: Missing docstring in public function
functional.py:3821 in public function `interpolate` (skipping F811,B950):
D205: 1 blank line required between summary line and description (found 0)
functional.py:3821 in public function `interpolate` (skipping F811,B950):
D400: First line should end with a period (not 'n')
functional.py:4062 in public function `upsample_nearest` (skipping F811):
D103: Missing docstring in public function
functional.py:4067 in public function `upsample_nearest` (skipping F811):
D103: Missing docstring in public function
functional.py:4100 in public function `upsample_bilinear` (skipping F811):
D103: Missing docstring in public function
functional.py:4107 in public function `upsample_bilinear` (skipping F811):
D103: Missing docstring in public function
functional.py:4114 in public function `upsample_bilinear` (skipping F811):
D103: Missing docstring in public function
functional.py:4121 in public function `upsample_bilinear` (skipping F811):
D103: Missing docstring in public function
functional.py:4174 in public function `grid_sample`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:4174 in public function `grid_sample`:
D400: First line should end with a period (not 'e')
functional.py:4315 in public function `affine_grid`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:4315 in public function `affine_grid`:
D400: First line should end with a period (not 'f')
functional.py:4315 in public function `affine_grid`:
D401: First line should be in imperative mood (perhaps 'Generate', not 'Generates')
functional.py:4608 in public function `triplet_margin_loss`:
D200: One-line docstring should fit on one line with quotes (found 3)
functional.py:4608 in public function `triplet_margin_loss`:
D400: First line should end with a period (not 's')
functional.py:4643 in public function `triplet_margin_with_distance_loss`:
D200: One-line docstring should fit on one line with quotes (found 3)
functional.py:4705 in public function `normalize`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
functional.py:4733 in public function `assert_int_or_pair`:
D103: Missing docstring in public function
functional.py:4743 in public function `unfold`:
D401: First line should be in imperative mood (perhaps 'Extract', not 'Extracts')
functional.py:4773 in public function `fold`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:4773 in public function `fold`:
D400: First line should end with a period (not 'g')
functional.py:4773 in public function `fold`:
D401: First line should be in imperative mood (perhaps 'Combine', not 'Combines')
functional.py:4800 in private function `_in_projection_packed`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:4800 in private function `_in_projection_packed`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
functional.py:4867 in private function `_in_projection`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:4867 in private function `_in_projection`:
D400: First line should end with a period (not 'y')
functional.py:4867 in private function `_in_projection`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
functional.py:5128 in public function `multi_head_attention_forward`:
D205: 1 blank line required between summary line and description (found 0)
functional.py:5128 in public function `multi_head_attention_forward`:
D400: First line should end with a period (not ':')
160
```
**AFTER:**
```
functional.py:3709 in public function `upsample` (skipping F811,B950):
D103: Missing docstring in public function
functional.py:3714 in public function `upsample` (skipping F811,B950):
D103: Missing docstring in public function
functional.py:3798 in public function `interpolate` (skipping F811,B950):
D103: Missing docstring in public function
functional.py:3803 in public function `interpolate` (skipping F811,B950):
D103: Missing docstring in public function
functional.py:3808 in public function `interpolate` (skipping F811,B950):
D103: Missing docstring in public function
functional.py:3813 in public function `interpolate` (skipping F811):
D103: Missing docstring in public function
functional.py:4068 in public function `upsample_nearest` (skipping F811):
D103: Missing docstring in public function
functional.py:4073 in public function `upsample_nearest` (skipping F811):
D103: Missing docstring in public function
functional.py:4106 in public function `upsample_bilinear` (skipping F811):
D103: Missing docstring in public function
functional.py:4113 in public function `upsample_bilinear` (skipping F811):
D103: Missing docstring in public function
functional.py:4120 in public function `upsample_bilinear` (skipping F811):
D103: Missing docstring in public function
functional.py:4127 in public function `upsample_bilinear` (skipping F811):
D103: Missing docstring in public function
functional.py:4742 in public function `assert_int_or_pair`:
D103: Missing docstring in public function
13
```
The file contained several docstring errors. I have fixed all of them (hopefully) and have tried to improve the overall readability of the code. For the most part, I have included relevant descriptions of the functions (taken from the official PyTorch docs). In cases where a function is purely mathematical or is difficult to describe in one line, I have included only references.
For testing, I relied on my local system and created a separate file. For the final edits, I changed the contents of the forked repo directly, as is already visible.
Kindly review @svekars @subramen @kit1980
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112856
Approved by: https://github.com/kit1980
In https://github.com/pytorch/pytorch/pull/99243, a check was added to ensure the `size` only contained integers.
This PR updates the check to also include numpy integers based on this comment (cc @kit1980): https://github.com/pytorch/pytorch/pull/99243#issuecomment-1646736646. Similar to the other commenter, I also ran into issues where existing software broke due to this after upgrading to PT2.1:
```
if not torch.jit.is_scripting():
if not all(_is_integer(x) for x in size):
> raise TypeError(
"expected size to be one of int or Tuple[int] or Tuple[int, int] or "
f"Tuple[int, int, int], but got size with types {[type(x) for x in size]}"
)
E TypeError: expected size to be one of int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], but got size with types [<class 'numpy.int64'>, <class 'numpy.int64'>]
/conda-env/lib/python3.8/site-packages/torch/nn/functional.py:3924: TypeError
```
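For illustration, a check of this kind could look like the sketch below. The names `is_integer_like` and `check_size` are hypothetical and are not the helper used in `functional.py`; NumPy is treated as an optional dependency here.

```python
try:
    import numpy as np  # optional; only needed to recognise NumPy scalars
except ImportError:
    np = None


def is_integer_like(x):
    """Return True for Python ints and, if NumPy is installed, NumPy integers."""
    if isinstance(x, int):
        return True
    return np is not None and isinstance(x, np.integer)


def check_size(size):
    """Hypothetical validator: reject any non-integer entry in `size`."""
    if not all(is_integer_like(x) for x in size):
        raise TypeError(
            f"expected integer sizes, but got types {[type(x) for x in size]}"
        )
    # Normalise to plain Python ints so downstream code sees one type.
    return tuple(int(x) for x in size)
```

With a check like this, a `size` such as `(np.int64(4), np.int64(4))` would be accepted, while `(4, 4.0)` would still raise `TypeError`.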
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110778
Approved by: https://github.com/mikaylagawarecki
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.
I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, there seem to be no instances of it in our codebase, so I am enabling it so that it stays that way. :)
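As a small illustration of what RUF017 is about (the data below is made up), flattening nested lists with `sum(..., [])` copies the accumulated list on every addition, while `itertools.chain.from_iterable` is one linear-time alternative:

```python
from itertools import chain

nested = [[1, 2], [3], [4, 5, 6]]

# The pattern the rule flags: each `+` builds a new list,
# so the total work grows quadratically with the number of items.
flat_quadratic = sum(nested, [])

# One linear-time alternative: iterate over the sublists lazily.
flat_linear = list(chain.from_iterable(nested))
```

Both produce the same flattened list; only the cost differs.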
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
Fixes #99148, raising an error if output_ratio's size > 2.
Justification for changes:
If an output size is not specified but an output ratio is, we call fractional_max_pool2d_with_indices. We then generate the value of output_size based on the first two elements of the output_ratio (line ~480 of torch/nn/functional.py).
Thus, we should raise a ValueError when the user passes an output_ratio (instead of an output_size) and the number of elements in output_ratio exceeds two. We must raise the error before calling torch._C._nn.fractional_max_pool2d, as the value of output_size passed into torch._C._nn.fractional_max_pool2d is guaranteed to be of size 2 (the existing code generates it from the first two indices of the passed-in ratio).
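A rough sketch of the intended validation, written as a standalone helper. The name `output_size_from_ratio`, its signature, and the error message are hypothetical and do not mirror the code in `functional.py`; a scalar ratio is assumed to apply to both dimensions.

```python
def output_size_from_ratio(input_hw, output_ratio):
    """Hypothetical helper: derive a 2-D output size from an output ratio."""
    # Assumption: a single number applies to both spatial dimensions.
    if isinstance(output_ratio, (int, float)):
        output_ratio = (output_ratio, output_ratio)
    # Fail early instead of silently ignoring everything past the second entry.
    if len(output_ratio) > 2:
        raise ValueError(
            "output_ratio must be a single number or a pair, "
            f"but got {len(output_ratio)} elements"
        )
    height, width = input_hw
    return [int(height * output_ratio[0]), int(width * output_ratio[1])]
```

For example, `output_size_from_ratio((10, 20), (0.5, 0.5))` yields `[5, 10]`, while a three-element ratio raises `ValueError` before any kernel would be called.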
I would be happy to iterate on this if there are any issues.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99507
Approved by: https://github.com/mikaylagawarecki
Summary:
This fixes an issue raised in [is_causal parameter in torch.nn.TransformerEncoderLayer.forward does not work #96941](https://github.com/pytorch/pytorch/issues/96941) where results computed with is_causal do not properly reflect causal masking.
In PyTorch 2.0, Accelerated PT Transformers added the is_causal parameter to legacy nn.Transformer* and nn.MHA APIs aligned with and intended to engage the is_causal parameter of the new scaled_dot_product_attention (SDPA) operator.
At present is_causal works differently for Transformer* modules, the nn.MHA and F.MHA:
* The nn.Transformer* modules treat is_causal as an optional indicator about the format of attn_mask. This is because some layers (such as the CLIP layer) use the attention mask in the layer, and thus the attn_mask was a required feature.
* Initially, nn.MHA and F.MHA were defined to align with F.SDPA in behavior: a user may specify either the attention mask, or is_causal, but not both. It seemed to make sense at the time to align SDPA and MHA, especially since there was a larger overlap of parameters which have since changed, e.g., with the removal of need_weights from SDPA. (See below for why this makes sense.)
Unfortunately, this does not work because of how MHA was changed to handle the need_weights parameter. When need_weights is present, we do not (any more) call SDPA because support for need_weights was removed from SDPA before the release. The rationale is that need_weights defeats all optimization at the foundation of SDPA performance. Having the flag might thus mislead users into thinking they get good performance and have them disappointed when they enable a legacy feature of MHA which massively degrades performance. (They might not think anything of enabling that, because it is on by default in MHA today, which leads to more issues.)
Since SDPA no longer supports need_weights, we need to pick a separate path which implements attention using a set of discrete operations that allocates a tensor for weights. Alas, this code path does not have support for is_causal, because attention is implemented as matmul and using the attention mask. Thus, is_causal has no impact. (A substantially similar situation arises with how kpm is implemented today because Nested Tensors are not supported by torch.compile() in 2.0.)
This problem was masked because all uses of legacy nn.MHA (and F.MHA) come through nn.Transformer*, which called self-attention (i.e., nn.MHA) only ever with the attention mask attn_mask, and never with is_causal, a missed optimization opportunity that would have been addressed in a future performance update.
Regrettably, always calling nn.MHA with attn_mask prevented diagnosing of the issue of not having a suitable attention mask when need_weights support was dropped from SDPA and a discrete implementation of attention was added for that scenario, and for the execution path with key_padding_mask.
We have two options to address this issue:
Solution 1: Whenever nn.MHA and F.MHA are executed with is_causal set, we internally create a causal mask at significant expense of allocating a tensor and filling it with a triangular causal matrix. This increases memory usage, and runtime, for allocating a causal mask. To add insult to injury, in all current (and likely future) execution scenarios, MHA is called by a model using the nn.Transformer API which already has that matrix and passes it from nn.module to nn.module. Then the passing in of attn_mask has to be suppressed by nn.TransformerEncoderLayer, only for nn.MHA to immediately allocate the very same tensor again to satisfy the requirement to have an attention mask for the computation. (We expect new use cases to use SDPA directly.)
Solution 2: We align the behavior of nn.MHA and F.MHA with the rest of the existing nn.Transformer API, and require the attention mask to be passed into nn.MHA in addition to is_causal, which becomes an optional indicator about the nature of the attention mask rather than an alternative to attn_mask. Then, when we choose the code path for processing MHA with need_weights or a key_padding_mask, we have the attn_mask passed down through the nn.Transformer* hierarchy, without the added overhead of allocating an attention mask as in Solution 1.
This PR implements solution 2 which offers better performance and in retrospect aligns MHA better with the rest of the Transformer modules as the definition of SDPA evolved into a more streamlined high-performance operator. It ostensibly changes how is_causal works, by requiring the attention mask to be specified. However, as described here, and as shown in the submitted issue, is_causal is not working as intended today, so it requires a change regardless.
In that sense, a change in API does not occur per se, as the current implementation is not working, and a change has to occur either way to resolve the submitted issue, breaking any use cases that depend on the current implementation. Checks exist (and more can be added) that flag any scenarios where is_causal is passed as True but no attention mask is provided, ensuring that there is no silent change from even the faulty behavior present in 2.0.
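To make the idea concrete, here is a minimal pure-Python sketch of such a check together with the triangular causal mask it expects. The helper names and the error message are hypothetical, and the mask convention (True marks a position that must not be attended to) is an assumption of this sketch.

```python
def causal_mask(seq_len):
    """Build a boolean mask where entry [i][j] is True if j lies in i's future."""
    return [[j > i for j in range(seq_len)] for i in range(seq_len)]


def validate_mask_args(attn_mask, is_causal):
    """Hypothetical check: is_causal is only a hint about attn_mask, not a substitute."""
    if is_causal and attn_mask is None:
        raise RuntimeError(
            "is_causal=True was passed without attn_mask; "
            "provide the causal mask explicitly"
        )
    return attn_mask
```

Under this scheme the caller builds the mask once, for example `causal_mask(3)`, and passes it down together with `is_causal=True`, so no layer has to allocate the mask a second time.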
As an upside, the present implementation will improve performance by addressing the passing of the is_causal flag from Transformer modules to MHA, speeding up training for these examples, e.g., finetuning BERT, RoBERTa, XLM-R models.
Differential Revision: D44245725
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97214
Approved by: https://github.com/albanD
Fixes #96813.
Comments:
1. Wasn't able to test since tools/nightly.py does not allow for GPU build (and I don't want to build from scratch).
2. In theory, the bug (i.e. NaNs) can still occur when beta is very small (e.g. `beta=1e-50`), but not sure whether anybody cares.
3. Some checks within the smooth_l1_loss C++ code could be changed to check for `beta > 0` instead of `beta >= 0`.
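
For context, a scalar sketch of the smooth L1 formula shows why `beta` matters: the quadratic branch divides by `beta`, so `beta == 0` has to be treated as plain L1. This is an illustrative reimplementation under that assumption, not the code changed in this PR.

```python
def smooth_l1(x, y, beta=1.0):
    """Scalar smooth L1 loss between x and y (illustrative only)."""
    if beta < 0:
        raise ValueError("beta must be non-negative")
    diff = abs(x - y)
    if beta == 0.0:
        # The quadratic region is empty, so the loss reduces to L1
        # and the division by beta below is never reached.
        return diff
    if diff < beta:
        return 0.5 * diff * diff / beta
    return diff - 0.5 * beta
```

A very small but non-zero `beta` still goes through the division, which is the situation comment 2 refers to.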
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97022
Approved by: https://github.com/jbschlosser
# Summary
This PR adds an optional kwarg to torch.nn.functional.scaled_dot_product_attention().
The new kwarg is a scaling factor that is applied after the q@k.T step of the computation. The efficient kernel was updated to support it, and the flash and math kernels received minimal updates to support it as well.
This will reduce the complexity of #94729 and has been requested by a couple of users.
# Review Highlights
- As far as I know I did this the correct way and this is both BC and FC compliant. However, I always seem to break internal workloads, so I would love it if someone could advise whether I did this right.
- I named the optional arg 'scale'. This is probably dumb and I should name it 'scale_factor'. I will make this change but this is annoying and it will require someone thinking we should rename.
- 'scale' is interpreted as `Q@K.T * (scale)`
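
The effect of `scale` can be sketched with a tiny pure-Python reference for a single attention head. This is a simplified illustration on nested lists, not the kernel code; the default of `1/sqrt(d)` when `scale` is not given is an assumption of this sketch.

```python
import math


def sdpa_reference(q, k, v, scale=None):
    """Toy attention: softmax((q @ k.T) * scale) @ v on nested lists."""
    d = len(q[0])
    if scale is None:
        scale = 1.0 / math.sqrt(d)  # assumed default scaling
    out = []
    for q_row in q:
        # Scaled dot product of this query with every key.
        scores = [sum(a * b for a, b in zip(q_row, k_row)) * scale for k_row in k]
        # Numerically stable softmax over the key dimension.
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the value rows.
        out.append(
            [sum(w * v_row[c] for w, v_row in zip(weights, v)) for c in range(len(v[0]))]
        )
    return out
```

Passing `scale=0.0` makes every score equal, so the output is just the mean of the value rows, which is a quick way to see that the factor multiplies the scores before the softmax.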
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95259
Approved by: https://github.com/cpuhrsch
Summary: fix src and pad mask bool regression
This fixes a regression introduced previously with #92733. That PR unified testing of masks to remove Byte Tensors as permissible mask, introduced mask compatibility check, and mask conversion to FP mask. The problem addressed in this PR was that after the first mask had been converted, a check for mask compatibility would fail.
Test Plan: sandcastle & github
Differential Revision: D43782858
Fixes https://github.com/pytorch/pytorch/issues/95702
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96009
Approved by: https://github.com/malfet
# Summary
- Adds type hinting support for SDPA
- Updates the documentation adding warnings and notes on the context manager
- Adds scaled_dot_product_attention to the non-linear activation function section of nn.functional docs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94008
Approved by: https://github.com/cpuhrsch
Summary:
Regularize mask handling for attn_mask and key_padding_mask
* Update documentation to remove reference to byte masks (which were deprecated long ago)
* Introduce check and warn about deprecation if attn_mask and key_padding_mask types mismatch
* Convert all masks to float before combining
* Combine by adding
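The convert-then-add steps above can be sketched on flat Python lists. This is an illustration, not the code from the PR: `to_float_mask` and `combine_masks` are hypothetical names, and the sketch assumes the convention that `True` in a boolean mask marks a position to ignore.

```python
NEG_INF = float("-inf")


def to_float_mask(mask):
    """Convert a boolean mask (assumed: True = ignore) to an additive float mask."""
    if all(isinstance(x, bool) for x in mask):
        return [NEG_INF if x else 0.0 for x in mask]
    # Already a float mask: keep the values as they are.
    return [float(x) for x in mask]


def combine_masks(attn_mask, key_padding_mask):
    """Convert both masks to float, then combine them by element-wise addition."""
    a = to_float_mask(attn_mask)
    b = to_float_mask(key_padding_mask)
    return [x + y for x, y in zip(a, b)]
```

Because a masked position is `-inf` in either operand, the sum stays `-inf` wherever at least one mask blocks the position, and mixed bool/float inputs no longer need a separate code path.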
Test Plan: sandcastle & github CI
Differential Revision: D42653215
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92733
Approved by: https://github.com/ngimel, https://github.com/drisspg
# Summary
In preparation for the pt 2.0 launch, this PR updates SDPA's API and makes the function a public nn.functional function.
## Changes
### API
Previously the function signature was:
`scaled_dot_product_attention(query, key, value, attn_mask=None, need_attn_weights=False, dropout_p=0.0, is_causal=False) -> (Tensor, Tensor)`
Updated signature:
`scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) -> Tensor`
This PR removes the need_attn_weights optional boolean variable and updates the return type to a single tensor.
#### Reasoning:
The main goal of this function is to provide an easy interface for users to call into fused attention kernels (e.g. FlashAttention). The fused kernels do not currently support arbitrary attn_mask or dropout, but there is a PR to mem-efficient attention to enable these. We want to have the API surface ready for when the backing kernels get updated.
The fused kernels save on memory usage by not materializing the weights, and it is unlikely that a fast fused implementation will enable this feature, so we are removing it.
Discussed with folks at FAIR/Xformers and +1 this API change.
#### Make function Public
In preparation for the pt 2.0 launch we make the function public to start generating user feedback.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92189
Approved by: https://github.com/cpuhrsch
Summary: Introduce causal mask
This PR introduces a causal mask option _causal_mask (as well as causal mask detection if attn_mask is provided), since current custom kernels do not support arbitrary masks.
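As a rough illustration of what causal mask detection involves, the sketch below builds a causal mask and checks whether a supplied mask matches it. The function names are hypothetical, the masks are nested Python lists rather than tensors, and `True` is assumed to mark a blocked (future) position; the PR's actual detection logic may differ.

```python
def causal_mask(n):
    """Return an n x n boolean mask where True marks a blocked (future) position."""
    return [[j > i for j in range(n)] for i in range(n)]


def is_causal(mask):
    """Report whether a square boolean mask is exactly the causal pattern."""
    n = len(mask)
    return all(len(row) == n for row in mask) and mask == causal_mask(n)
```

When detection succeeds, a caller can set a causal flag and skip passing the explicit mask, which matters when the fast kernels accept the flag but not arbitrary masks.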
Test Plan: sandcastle & github ci/cd
Differential Revision: D41723137
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90508
Approved by: https://github.com/albanD
# Summary
This exposes the _scaled_dot_product_attention function to Python in the nn namespace. It is still underscored because the API for args and kwargs is still in flux for the next few weeks, and it will eventually land as a prototype feature.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85044
Approved by: https://github.com/cpuhrsch
This is a new version of #15648 based on the latest master branch.
Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.
In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)
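For reference, a skipped doctest looks roughly like the following. `add_one` is a made-up function used only to show where the directive goes; in this PR the directive is added to existing docstrings whose examples currently fail.

```python
def add_one(x):
    """Return ``x + 1``.

    Example:
        >>> # xdoctest: +SKIP
        >>> add_one(1)
        2
    """
    return x + 1
```

The directive tells xdoctest to skip the example lines that follow it in the docstring, so a failing example stays in the documentation without breaking the dashboards.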
Fixes https://github.com/pytorch/pytorch/issues/71105
@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
I figured these out by unconditionally turning on a no-op torch function
mode on the test suite and then fixing errors as they showed up. Here's
what I found:
- _parse_to failed internal assert when __torch_function__'ed because it
claims its name is "to" to the argument parser; added a name override
so we know how to find the correct name
- Infix operator magic methods on Tensor did not uniformly handle
__torch_function__ and TypeError to NotImplemented. Now, we always
do the __torch_function__ handling in
_wrap_type_error_to_not_implemented and your implementation of
__torch_function__ gets its TypeErrors converted to NotImplemented
(for better or for worse; see
https://github.com/pytorch/pytorch/issues/75462 )
- A few cases where code was incorrectly testing if a Tensor was
Tensor-like in the wrong way, now use is_tensor_like (in grad
and in distributions). Also update docs for has_torch_function to
push people to use is_tensor_like.
- is_grads_batched was dropped from grad in handle_torch_function, now
fixed
- Report that you have a torch function even if torch function is
disabled if a mode is enabled. This makes it possible for a mode
to return NotImplemented, pass to a subclass which does some
processing and then pass back to the mode even after the subclass
disables __torch_function__ (so the tensors are treated "as if"
they are regular Tensors). This brings the C++ handling behavior
in line with the Python behavior.
- Make the Python implementation of overloaded types computation match
the C++ version: when torch function is disabled, there are no
overloaded types (because they all report they are not overloaded).
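The TypeError-to-NotImplemented conversion mentioned in the second bullet can be illustrated without PyTorch. The decorator below is a simplified, hypothetical analog of `_wrap_type_error_to_not_implemented`: it covers only the exception conversion and none of the `__torch_function__` handling, and `Meters` and `Feet` are invented classes.

```python
import functools


def wrap_type_error_to_not_implemented(fn):
    """Turn a TypeError raised by a binary operator into NotImplemented."""

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except TypeError:
            # Returning NotImplemented lets Python try the reflected method.
            return NotImplemented

    return wrapped


class Meters:
    def __init__(self, value):
        self.value = value

    @wrap_type_error_to_not_implemented
    def __add__(self, other):
        if not isinstance(other, Meters):
            raise TypeError("can only add Meters")
        return Meters(self.value + other.value)


class Feet:
    def __init__(self, value):
        self.value = value

    def __radd__(self, other):
        # Reached only because Meters.__add__ returned NotImplemented.
        return Meters(other.value + self.value * 0.3048)
```

The trade-off hinted at by "for better or for worse" is visible here: a genuine bug that raises TypeError inside `__add__` is also swallowed and reported as an unsupported operand.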
Signed-off-by: Edward Z. Yang <ezyangfb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75484
Approved by: https://github.com/zou3519
Closes #44459
This migrates the python implementation of `_pad_circular` to ATen and
removes the old C++ implementation that had diverged from python.
Note that `pad` can't actually use this until the
forward-compatibility period is over.
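For readers unfamiliar with circular padding, a minimal one-dimensional sketch on a Python list follows. It is not the ATen implementation: `pad_circular_1d` is a hypothetical helper, and the restriction that padding may not exceed the input length is an assumption of this sketch.

```python
def pad_circular_1d(xs, left, right):
    """Pad a list by wrapping values around from the opposite end."""
    n = len(xs)
    # Assumption: the padding may wrap around the input at most once.
    if left > n or right > n:
        raise ValueError("padding must not exceed the sequence length")
    # The left pad is the tail of the input; the right pad is its head.
    return xs[n - left:] + xs + xs[:right]
```

For example, padding `[1, 2, 3, 4]` with one element on the left and two on the right yields `[4, 1, 2, 3, 4, 1, 2]`.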
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73410
Approved by: https://github.com/ezyang
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72871
We do this same trick in the native MHA implementation; backport it for purposes of fair comparison.
ghstack-source-id: 149526858
Test Plan: CI
Reviewed By: ngimel
Differential Revision: D34176090
fbshipit-source-id: 8b578c29c4dcf0d85bae74dfbbb82db9a8f32dc7
(cherry picked from commit fd50170935)