Summary: Decoder native joins the dead code society
With the recent introduction of PT2, we no longer need native decoder operators:
1 - full-function SDPA kernels can be used to implement cross-attention efficiently without the (slower) decoder MHA blob.
2 - torch.compile() generates more efficient code across many platforms from the python implementation of decoders than the decoder layer blob by tailoring code to target
Test Plan: github & sandcastle
Differential Revision: D43811808
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96025
Approved by: https://github.com/ezyang, https://github.com/albanD
# Summary
This PR adds an optional kwarg to torch torch.nn.functional.scaled_dot_product_attention()
The new kwarg is a scaling factor that is applied after the q@k.T step of the computation. Made updates to the efficient kernel to support but flash and math were minimally updated to support as well.
Will reduce the complexity of: #94729 and has been asked for by a couple of users.
# Review Highlights
- As far as I know I did this the correct way and this both BC and FC compliant. However I always seem to break internal workloads so I would love if someone can advice I did this right?
- I named the optional arg 'scale'. This is probably dumb and I should name it 'scale_factor'. I will make this change but this is annoying and it will require someone thinking we should rename.
- 'scale' is interpreted as `Q@K.T * (scale)`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95259
Approved by: https://github.com/cpuhrsch
Summary: fix src and pad mask bool regression
This fixes a regression introduced previously with #92733. That PR unified testing of masks to remove Byte Tensors as permissible mask, introduced mask compatibility check, and mask conversion to FP mask. The problem addressed in this PR was that after the first mask had been converted, a check for mask compatibility would fail.
Test Plan: sandcastle & github
Differential Revision: D43782858
Fixes https://github.com/pytorch/pytorch/issues/95702
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96009
Approved by: https://github.com/malfet
# Summary
Previously, for NestedTensor inputs flash_attention was disabled due to an Illegal Memory Access error that was occurring on the "cutlass" branch of flash-attention that had be incorporated into core. Since we have switched to the main branch of flash_attention we the existing repro script did not produce the same memory error. This PR re-enables the FlashAttention Path for NTs. As well it unifies the nested preprocessing between the two implementations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95438
Approved by: https://github.com/mikaylagawarecki
# Summary
Add more checks around shape constraints as well as update the sdp_utils to properly catch different head_dims between qk and v for flash_attention which is not supported.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94274
Approved by: https://github.com/cpuhrsch
# Summary
- Adds a large parameter sweep for testing the various configs a user can call sdpa with and compares the deviation of the fused kernels vs the eager math fallback to test for correctness.
- Sm86 + head_dim==128 is throwing an IMA for memory efficient attention. We add a filter for use_mem_efficient_attention(). This has since been fixed in the upstream Xformers version but will likely not make it for branch cut.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94009
Approved by: https://github.com/cpuhrsch
# Summary
This PR creates _flash_attention_backward and _scaled_dot_product_flash_attention_backward native functions and registers them to the respective derivatives.yaml.
The goal is to replicate the torch.autograd.Function defined in the FlashAttention repo [here](33e0860c9c/flash_attn/flash_attn_interface.py (L126)) natively in PyTorch. One thing that we don't have access to is ctx.save_for_backward in native PyTorch so in order to save these variables I extended the returned objects from the forward functions.
### MetaFunctions
I also updated the FlashAttention meta functions to mirror the real outputs now. As well I added a meta registration for backwards. I have an XLMR training script and while eager training now works with FlashAttention compiling this module fails with the inductor error down below.
### Questions?
Performance issues vs mem efficient when using torch.nn.mha_forward
TorchCompile -> See purposed solution below.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92917
Approved by: https://github.com/cpuhrsch
# Summary
Add support for fused attention kernels (FlashAttention and memory-efficient attention) on Windows. Previously we could not do this because the fixes required c++17 to do this but we have since update the PyTorch standard.
This PR:
- Changes invocations of unsigned long to the fixed width integer type
- Adds in the #define FP16_SWITCH(COND, ...) which has been added to the flash_attention main branch
- Changes the some macros used within mem-efficient attention code in order to work around the VA_ARG discrepancy between clang/gcc and msvc. An alternative would be setting the global flag Zc:preprocessor
- Selectively applies /Zc:lambda to only the mem-efficient sources since applying this globally caused quantization files to not compile
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91909
Approved by: https://github.com/cpuhrsch
Summary:
Regularize mask handling for attn_mask and key_padding_mask
* Update documentation to remove reference to byte masks (which were deprecated long ago)
* Introduce check and warn about deprecation if attn_mask and key_padding_mask types mismatch
* Convert all masks to float before combining
* Combine by adding
Test Plan: sandcastle & github CI
Differential Revision: D42653215
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92733
Approved by: https://github.com/ngimel, https://github.com/drisspg
# Summary
In preparation for pt 2.0 launch this PR updates SDPA's API and makes the function a nn.funcitonal public function.
## Changes
### API
Previously the the function signature was:
`scaled_dot_product_attention(query, key, value, attn_mask=None, need_attn_weights=False, dropout_p=0.0, is_causal=False) -> (Tensor, Tensor)`
Updated signature:
`scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) -> Tensor`
This PR removes the need_attn_weights optional boolean variable and updates the return type to a singular tensor.
#### Reasoning:
The main goal of this function is to provide an easy interface for users to call into fused attention kernels e.g. (FlashAttention). The fused kernels do not currently support arbitrary attn_mask or dropout but there is a PR to mem-efficient attention to enable these. We want to have the API surface ready for when the backing kernels get updated.
The fused kernels save on memory usage by not materializing the weights and it is unlikely that a fast fused implementation will enable this feature so we are removing.
Discussed with folks at FAIR/Xformers and +1 this API change.
#### Make function Public
In preparation for the pt 2.0 launch we make the function public to start to generate user feedback
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92189
Approved by: https://github.com/cpuhrsch
# Summary
Memory efficient attention is a non deterministic algorithm.
This PR ensures that the sdp_choice will allow for mem-efficient to be used as the backend to SDPA if we are in warn only mode. Otherwise if we have enabled determinism and and set warn_only to False sdp_choice will not return memory efficient attention as the backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91979
Approved by: https://github.com/cpuhrsch
# Summary
This PR updates the second return value from SDPA to return an empty tensor of size 0 not what it would be if need_attn_weights is True. Also updates the meta function to account for this change.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91782
Approved by: https://github.com/cpuhrsch
# Summary
Creates a callable native function that can determine which implementation of scaled dot product will get called. This allows to bump re-order the runtime dispatch of SDP to enable autograd.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89029
Approved by: https://github.com/cpuhrsch
# Registers the derivative for mem efficient backward
- Use gradcheck to test correctness. The kernel is not implemented for fp64 so run checks with bumped tolerances in fp32
- I also made updates based off of Xformer main branch and flash-attention cutlass branch.
- This will enable the fused backward to be called for scaled dot product attention
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88856
Approved by: https://github.com/cpuhrsch
# Registers the derivative for mem efficient backward
- Use gradcheck to test correctness. The kernel is not implemented for fp64 so run checks with bumped tolerances in fp32
- I also made updates based off of Xformer main branch and flash-attention cutlass branch.
- This will enable the fused backward to be called for scaled dot product attention
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88856
Approved by: https://github.com/cpuhrsch
Fixes T135842750 (follow-up for #87377)
## Description
At present, having both `src_key_padding_mask` and `src_mask` at the same time is not supported on the fastpath in Transformer and Multi-Head Attention.
This PR enables using both masks on the fastpath on CPU and GPU: if both masks are passed, we merge them into a 4D mask in Python and change mask type to 2 before passing downstream.
Downstream processing in native code is not changed, as it already supports 4D mask. Indeed, it is done depending on the device:
- on CUDA, by `SoftMax.cu::masked_softmax_cuda`. When mask type is 2, it calls either `dispatch_softmax_forward` -> `softmax_warp_forward` or `at::softmax` (depending on the input size). In both cases 4D mask is supported.
- on CPU, by `SoftMax.cpp::masked_softmax_cpp`. It calls `hosted_softmax` which supports 4D mask.
## Tests
- Extended `test_mask_check_fastpath` to check that fast path is indeed taken in Transformer when two masks are passed
- Added `test_multihead_self_attn_two_masks_fast_path_mock` to check that fast path is taken in MHA when two masks are passed
- Added `test_multihead_self_attn_two_masks_fast_path` to check that fast and slow paths give the same result when two masks are passed in MHA
- `test_masked_softmax_mask_types` now covers mask type 2
- `test_transformerencoderlayer_fast_path` (CPU smoke test) is expanded to the case of both masks provided simultaneously
- `test_masked_softmax_devices_parity` checks that mask type 2 is accepted by CPU and CUDA paths
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88488
Approved by: https://github.com/mikekgfb
## Issues
Fixes https://github.com/pytorch/pytorch/issues/81129#issuecomment-1179435674
## Description
Passing a 2D attention mask `src_mask` into the fast path of `TransformerEncoderLayer` in CPU was causing an error and so was disabled in https://github.com/pytorch/pytorch/pull/81277. This PR unrolls this fix, enabling `src_mask` on the fast path:
- Either attention mask `src_mask` of shape `(L, L)` or padding mask `src_key_padding_mask` of shape `(B, L)` are now allowed on the CPU fast path. If softmax is applied along the last dimension (as in multi-head attention), these masks are processed without expanding them to 4D. Instead, when iterating through the input, `Softmax.cpp::host_softmax` converts the index to match the mask dimensions, depending on the type.
- If softmax is applied along the dimension other than the last, `Softmax.cpp::masked_softmax_cpu` expands masks to 4D, converting them to `mask_type=2`. Theoretically one could also add special optimized cases for `dim=0, 1, 2` and process them without mask expansion, but I don't know how often is that used
## Tests:
- `test_transformerencoderlayer_fast_path` is extended to cover both attention mask and padding mask
- `test_masked_softmax_mask_types_0_1` is added to ensure results from CPU softmax with attention and padding masks match the explicit slow calculation
- `test_masked_softmax_devices_parity` is added to ensure results from masked softmax on CPU and CUDA match
## Note
I had to replace `float` with `torch.get_default_dtype()` in a couple of tests for the following reason:
- `test_nn.py` [sets the default type to `torch.double`](https://github.com/pytorch/pytorch/blob/master/test/test_nn.py#L24-L26)
- If I execute `test_nn.py` and `test_transformers.py` in one `pytest` run, this default still holds for transformer tests
- Some tests in `test_transformers.py` which were previously following the slow path now switched to fast path, and hard-coded `float` started clashing with default `double`
Let me know if there is a better way around it - or maybe I'm not supposed to run tests with `pytest` like this
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87377
Approved by: https://github.com/mikekgfb, https://github.com/weiwangmeta, https://github.com/malfet
# Summary
Use the private _scaled_dot_product_attention to support _native_multiheaded_attention. _SDP provides access to fused kernels when certain conditions are meant enabling a speed up for MHA.
cc @cpuhrsch @jbschlosser @bhosmer @mikaylagawarecki
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87312
Approved by: https://github.com/cpuhrsch
# Summary
Add in a torch.backends.cuda flag and update context manager to pic between the three implementations of the scaled_dot_product_attention.
cc @cpuhrsch @jbschlosser @bhosmer @mikaylagawarecki
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87946
Approved by: https://github.com/cpuhrsch
# Summary
- This code creates the runtime dispatch system for choosing a performant fused SDP kernel. The only choice of fused kernel is flash_attention. It also creates python flags and a context manager that can be used to turn off and on behavior for dispatch.
- This also adds support for flash_attention with dense tensors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85984
Approved by: https://github.com/cpuhrsch
# Summary
This exposes the _scaled_dot_product_attention function to python in the nn namespace. It is still underscored because the api for args, and kwargs is still in flux for the next few weeks and will eventually land as a prototype feature.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85044
Approved by: https://github.com/cpuhrsch
# Summary
This exposes the _scaled_dot_product_attention function to python in the nn namespace. It is still underscored because the api for args, and kwargs is still in flux for the next few weeks and will eventually land as a prototype feature.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85044
Approved by: https://github.com/cpuhrsch
Summary: Check that fastpath is taken, which type (sparsity fastpath or normal) for mask that is aligned and one that is not.
Test Plan: buck test caffe2/test:test_transformers
Differential Revision: D38259928
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82999
Approved by: https://github.com/jbschlosser
Adds an initial private API version of the SDP interface.
Signature:
```
_scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None,
float dropout_p=0.0, bool need_attn_weights=True, bool is_causal=False) -> (Tensor, Tensor)
```
Returns a tuple of `(output, attn_weights)`.
Note the following:
* `need_attn_weights`: flag indicating that attention weights should be computed. This is useful to toggle off for flash attention as it does not materialize the weights by default, making it more expensive to return them.
* Boolean attention mask support only; `True` values within `attn_mask` indicate that the element should take part in attention (notably, this is reverse of MHA, which uses `True` to mask *out* values). Mask is optional.
* `is_causal`: Temporary flag indicating whether to use a causal attention weighting. If this is set to `True`, it takes precedent over any value passed in for `attn_mask`. Longer term, the `is_causal` flagging can be subsumed into the `attn_mask` arg via tensor subclassing (see e.g. [CausalTensor](https://github.com/facebookresearch/xformers/blob/sparse_cleanup/xformers/sparse/causal_tensor.py) in xFormers).
* Testing is currently done via reference with the existing Python impl of `F._scaled_dot_product_attention`.
* This PR does not yet drop-in the new SDP anywhere. A future PR can hook it up in BT or MHA.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81956
Approved by: https://github.com/drisspg, https://github.com/erichan1
Adds an initial private API version of the SDP interface.
Signature:
```
_scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None,
float dropout_p=0.0, bool need_attn_weights=True, bool is_causal=False) -> (Tensor, Tensor)
```
Returns a tuple of `(output, attn_weights)`.
Note the following:
* `need_attn_weights`: flag indicating that attention weights should be computed. This is useful to toggle off for flash attention as it does not materialize the weights by default, making it more expensive to return them.
* Boolean attention mask support only; `True` values within `attn_mask` indicate that the element should take part in attention (notably, this is reverse of MHA, which uses `True` to mask *out* values). Mask is optional.
* `is_causal`: Temporary flag indicating whether to use a causal attention weighting. If this is set to `True`, it takes precedent over any value passed in for `attn_mask`. Longer term, the `is_causal` flagging can be subsumed into the `attn_mask` arg via tensor subclassing (see e.g. [CausalTensor](https://github.com/facebookresearch/xformers/blob/sparse_cleanup/xformers/sparse/causal_tensor.py) in xFormers).
* Testing is currently done via reference with the existing Python impl of `F._scaled_dot_product_attention`.
* This PR does not yet drop-in the new SDP anywhere. A future PR can hook it up in BT or MHA.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81956
Approved by: https://github.com/drisspg, https://github.com/erichan1
Summary:
Add test just to check if TransformerEncoder will crash when enumerating over params [with_no_grad, use_torchscript, training].
Motivation for this was that TransformerEncoder fast path (so with_no_grad=True) and use_torchscript=True would crash with the issue that NestedTensor doesn't have size. This was caused because the TransformerEncoder fast path generates a NestedTensor automatically as a perf optimization and torchscript attempts to find intermediate tensor sizes while it optimizes. But NestedTensor has not implemented a size method, so things fail.
This test goes together with this fix https://github.com/pytorch/pytorch/pull/79480
Test Plan:
```
buck build --show-output mode/opt -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=a100 mode/inplace //caffe2/test:transformers
./fbcode/buck-out/gen/caffe2/test/transformers#binary.par
```
Test runs and passes together with the changes from the PR above (I made another diff on top of this with those changes). Does not pass without the fix.
Reviewed By: mikekgfb
Differential Revision: D37222923
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79796
Approved by: https://github.com/zrphercule